diff --git a/third_party/xllm_atb_layers b/third_party/xllm_atb_layers
index d6aa214ce..96d3deb21 160000
--- a/third_party/xllm_atb_layers
+++ b/third_party/xllm_atb_layers
@@ -1 +1 @@
-Subproject commit d6aa214ce69acac8a3061ee8f0ef48b94dd3f5f6
+Subproject commit 96d3deb210b60479bfe2c58bf871a70649129303
diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h
index 3e86b572f..2c435bb35 100644
--- a/xllm/core/framework/model/model_input_params.h
+++ b/xllm/core/framework/model/model_input_params.h
@@ -68,7 +68,7 @@ struct OneRecModelInputParams {
   torch::Tensor cross_attn_kv_cu_seq_lens;
   torch::Tensor cross_attn_new_cache_slots;
   torch::Tensor cross_attn_block_tables;
-  std::vector cross_attn_kv_cu_seq_lens_vec;
+  std::vector cross_attn_kv_cu_seq_lens_vec;

   torch::Tensor encoder_token_ids;
   torch::Tensor encoder_positions;
diff --git a/xllm/core/framework/tokenizer/rec_tokenizer.h b/xllm/core/framework/tokenizer/rec_tokenizer.h
index 41b03f0bd..f01076b1e 100644
--- a/xllm/core/framework/tokenizer/rec_tokenizer.h
+++ b/xllm/core/framework/tokenizer/rec_tokenizer.h
@@ -23,6 +23,7 @@ limitations under the License.
 #include
 #include

+#include "common/types.h"
 #include "tokenizer.h"
 #include "tokenizer_args.h"
 #include "util/slice.h"
diff --git a/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp b/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp
index 6ea2148f7..421a86808 100644
--- a/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp
+++ b/xllm/core/layers/npu/npu_onerec_block_layer_impl.cpp
@@ -34,6 +34,11 @@ static constexpr uint64_t kOneRecWeightCountPerLayer = 79;
 // Decoder MoE mode weights count (exclude runtime tensors like expert_array).
 static constexpr uint64_t kOneRecMoeWeightCountPerLayer = 97;

+// OneRec attention linear best-practice defaults.
+// Keep them local to avoid exposing extra user-facing flags.
+static constexpr bool kEnableOneRecAclnnAttentionLinear = true;
+static constexpr int32_t kOneRecAclnnAttentionLinearMinTokens = 128;
+
 enum class OneRecBlockLayerTensorId : int32_t {
   // Self-attention layer norm
   IN_LAYER_NORM_WEIGHT = 0,
@@ -528,7 +533,9 @@ get_onerec_decoder_moe_weight_mapping() {
       kOneRecDecoderWeightMapping;

   mapping.emplace("layer.2.ffn.gate.weight", kInBlockSparseMoeGateWeight);
+  mapping.emplace("layer.2.ffn.router.weight", kInBlockSparseMoeGateWeight);
   mapping.emplace("2.ffn.gate.weight", kInBlockSparseMoeGateWeight);
+  mapping.emplace("2.ffn.router.weight", kInBlockSparseMoeGateWeight);

   mapping.emplace("layer.2.ffn.shared_experts.w1.weight",
                   kInMlpGateUpWeightSharedExpert);
@@ -594,7 +601,23 @@ NpuOneRecBlockLayerImpl::NpuOneRecBlockLayerImpl(const ModelContext& context,
   const auto& args = context.get_model_args();
   const auto& parallel_args = context.get_parallel_args();
   param_from_args(prefill_param_, args, parallel_args, /*is_prefill=*/true);
+  param_from_args(prefill_param_atb_, args, parallel_args, /*is_prefill=*/true);
+  prefill_param_atb_.matmulBackend = atb_speed::common::OpBackend::ATB;
   param_from_args(decode_param_, args, parallel_args, /*is_prefill=*/false);
+  if (FLAGS_enable_rec_prefill_only && is_decoder_) {
+    param_from_args(decoder_prefill_only_decode_param_,
+                    args,
+                    parallel_args,
+                    /*is_prefill=*/true);
+    decoder_prefill_only_decode_param_.emptyCrossAttn = false;
+    param_from_args(decoder_prefill_only_decode_param_atb_,
+                    args,
+                    parallel_args,
+                    /*is_prefill=*/true);
+    decoder_prefill_only_decode_param_atb_.emptyCrossAttn = false;
+    decoder_prefill_only_decode_param_atb_.matmulBackend =
+        atb_speed::common::OpBackend::ATB;
+  }

   const int32_t weight_count = prefill_param_.use_moe
                                    ? kOneRecMoeWeightCountPerLayer
@@ -636,6 +659,10 @@ void NpuOneRecBlockLayerImpl::param_from_args(
   param.isBF16 = args.dtype() == "bfloat16";
   param.isPack = true;
   param.supportSwiGLU = true;
+  // Shared experts in the current OneRec MoE path are loaded as bf16/fp
+  // weights. Do not force the dedicated SwigluQuant scale contract unless the
+  // shared expert path is explicitly wired for dynamic quant.
+  param.enableSwiGLUQuantForSharedExperts = false;
   param.supportLcoc = is_prefill;
   param.supportSpeculate = false;
   param.enableSplitFuse = FLAGS_enable_chunked_prefill && is_prefill;
@@ -649,6 +676,9 @@ void NpuOneRecBlockLayerImpl::param_from_args(
   param.isOneRecEncoder = !is_decoder_;
   param.enableOneRecPrefillOnly = FLAGS_enable_rec_prefill_only;
   param.backend = FLAGS_communication_backend;
+  param.matmulBackend = kEnableOneRecAclnnAttentionLinear
+                            ? atb_speed::common::OpBackend::ACLNN
+                            : atb_speed::common::OpBackend::ATB;
   param.rank = parallel_args.rank();
   param.worldSize = parallel_args.world_size();
   param.quantType = 0;
@@ -660,12 +690,21 @@ void NpuOneRecBlockLayerImpl::param_from_args(
       is_decoder_ ? args.decoder_head_dim() : args.head_dim();
   param.numAttentionHeadsPerRank = args_n_heads / param.worldSize;
   param.hiddenSizePerAttentionHead = args_head_dim;
-
-  std::optional optional_value =
-      is_decoder_ ? args.decoder_n_kv_heads().value_or(args.decoder_n_heads())
-                  : args.n_kv_heads().value_or(args.n_heads());
+  // Reuse an existing model capability bit to split the legacy and 3B paths.
+  // Current validated models follow:
+  //   - legacy model: moe_use_shared_experts = false
+  //   - 3B model: moe_use_shared_experts = true
+  param.useAttentionScaling = args.moe_use_shared_experts();
+
+  const auto general_kv_heads = args.n_kv_heads();
+  const auto decoder_kv_heads = args.decoder_n_kv_heads().has_value()
+                                    ? args.decoder_n_kv_heads()
+                                    : general_kv_heads;
+  const int64_t args_kv_heads =
+      is_decoder_ ? decoder_kv_heads.value_or(args.decoder_n_heads())
+                  : general_kv_heads.value_or(args.n_heads());
   param.numKeyValueHeadsPerRank =
-      static_cast(optional_value.value()) / param.worldSize;
+      static_cast(args_kv_heads / param.worldSize);
   param.rmsNormEps = args.rms_norm_eps();

   param.seqLen = {};
@@ -713,6 +752,7 @@ void NpuOneRecBlockLayerImpl::param_from_args(
     param.moe_config->hasSharedExpertGate = false;
     param.moe_config->moe_use_shared_experts = args.moe_use_shared_experts();
     param.moe_config->moe_num_shared_experts = args.n_shared_experts();
+    param.moe_config->enable_integrated_softmax_topk = true;

     param.moeLinearQuantType = {atb_speed::common::LinearType::FP,
                                 atb_speed::common::LinearType::FP,
@@ -723,13 +763,36 @@ void NpuOneRecBlockLayerImpl::param_from_args(

 void NpuOneRecBlockLayerImpl::verify_loaded_weights(
     const std::string& prefix) const {
-  const auto& weight_mapping =
-      [this]() -> const std::unordered_map& {
+  std::unordered_map filtered_weight_mapping;
+  const auto* weight_mapping = [this, &filtered_weight_mapping]()
+      -> const std::unordered_map* {
     if (prefill_param_.use_moe) {
-      return kOneRecDecoderMoeWeightMapping;
+      filtered_weight_mapping.clear();
+      const bool has_shared_experts =
+          prefill_param_.moe_config != nullptr &&
+          prefill_param_.moe_config->moe_use_shared_experts;
+      const bool has_shared_expert_gate =
+          prefill_param_.moe_config != nullptr &&
+          prefill_param_.moe_config->hasSharedExpertGate;
+      for (const auto& [name, index] : kOneRecDecoderMoeWeightMapping) {
+        bool should_include = true;
+        if (!has_shared_experts &&
+            name.find("shared_expert") != std::string::npos) {
+          should_include = false;
+        }
+        if (should_include && !has_shared_expert_gate &&
+            (name.find("shared_expert.gate") != std::string::npos ||
+             name.find("shared_expert_gate") != std::string::npos)) {
+          should_include = false;
+        }
+        if (should_include) {
+          filtered_weight_mapping.emplace(name, index);
+        }
+      }
+      return &filtered_weight_mapping;
     }
-    return is_decoder_ ? kOneRecDecoderWeightMapping
-                       : kOneRecEncoderWeightMapping;
+    return is_decoder_ ? &kOneRecDecoderWeightMapping
+                       : &kOneRecEncoderWeightMapping;
   }();

   // verify_loaded_weights() runs before merge_loaded_weights().
@@ -742,25 +805,12 @@ void NpuOneRecBlockLayerImpl::verify_loaded_weights(
     allowed_placeholders.insert(kInFfnWi1Weight);
     allowed_placeholders.insert(kInFfnWoWeight);
   }
-  const bool has_shared_experts =
-      prefill_param_.moe_config != nullptr &&
-      prefill_param_.moe_config->moe_use_shared_experts;
-
-  for (const auto& [name, index] : weight_mapping) {
+  for (const auto& [name, index] : *weight_mapping) {
     const auto sizes = at_weight_tensors_[index].sizes();
     const bool is_placeholder = (sizes.size() == 2 && sizes[0] == 1);
     const bool expected_placeholder = allowed_placeholders.count(index) > 0;
     const bool is_relative_bias = (index == kInRelativeAttentionBiasWeight);
-    const bool is_shared_optional = prefill_param_.use_moe &&
-                                    !has_shared_experts &&
-                                    (index == kInMlpGateUpWeightSharedExpert ||
-                                     index == kInMlpDownWeightSharedExpert ||
-                                     index == kInSharedExpertGateWeight ||
-                                     index == kInSharedExpertGateBias ||
-                                     index == kInSharedExpertGateOffset ||
-                                     index == kInSharedExpertGateScale);
-    if (is_placeholder && !expected_placeholder && !is_relative_bias &&
-        !is_shared_optional) {
+    if (is_placeholder && !expected_placeholder && !is_relative_bias) {
       CHECK(false) << "weight is not loaded for " << prefix << name;
     }
   }
@@ -958,11 +1008,14 @@ void NpuOneRecBlockLayerImpl::load_state_dict(const StateDict& state_dict) {
       }
       torch::Tensor mutable_tensor =
           (shard_dim >= 0 && parallel_args_.world_size() > 1)
-              ? state_dict.get_sharded_tensor(tensor_name,
+              ? state_dict.get_sharded_tensor(name,
                                               shard_dim,
                                               parallel_args_.rank(),
                                               parallel_args_.world_size())
               : tensor;
+      if (!mutable_tensor.defined()) {
+        continue;
+      }
       correct_tensor_dtype(mutable_tensor, tensor_name);
       at_weight_tensors_[weight_position] = mutable_tensor.to(device_);
       return;
@@ -1006,6 +1059,64 @@ void NpuOneRecBlockLayerImpl::load_state_dict(const StateDict& state_dict) {
     }
   }

+  const auto load_dense_fused_ffn_weights =
+      [this, &state_dict, &correct_tensor_dtype](const std::string& state_key,
+                                                 const torch::Tensor& tensor) {
+        if (absl::StrContains(state_key, ".ffn.experts.") ||
+            absl::StrContains(state_key, ".ffn.shared_experts.") ||
+            absl::StrContains(state_key, ".ffn.shared_expert.")) {
+          return;
+        }
+
+        if (absl::StrContains(state_key, ".DenseReluDense.weight1") ||
+            absl::StrContains(state_key, ".ffn.weight1")) {
+          torch::Tensor fused_gate_up =
+              (parallel_args_.world_size() > 1)
+                  ? state_dict.get_sharded_tensor(state_key,
+                                                  /*dim=*/0,
+                                                  parallel_args_.rank(),
+                                                  parallel_args_.world_size())
+                  : tensor;
+          if (!fused_gate_up.defined()) {
+            return;
+          }
+          CHECK_EQ(fused_gate_up.dim(), 2)
+              << "OneRec fused FFN weight1 must be 2D, got "
+              << fused_gate_up.sizes() << " from " << state_key;
+          CHECK_EQ(fused_gate_up.size(0) % 2, 0)
+              << "OneRec fused FFN weight1 dim0 must be even, got "
+              << fused_gate_up.sizes() << " from " << state_key;
+          correct_tensor_dtype(fused_gate_up, state_key);
+          std::vector chunks =
+              fused_gate_up.chunk(/*chunks=*/2, /*dim=*/0);
+          at_weight_tensors_[kInFfnWi0Weight] =
+              chunks[0].contiguous().to(device_);
+          at_weight_tensors_[kInFfnWi1Weight] =
+              chunks[1].contiguous().to(device_);
+          return;
+        }
+
+        if (absl::StrContains(state_key, ".DenseReluDense.weight2") ||
+            absl::StrContains(state_key, ".ffn.weight2")) {
+          torch::Tensor wo_weight =
+              (parallel_args_.world_size() > 1)
+                  ? state_dict.get_sharded_tensor(state_key,
+                                                  /*dim=*/1,
+                                                  parallel_args_.rank(),
+                                                  parallel_args_.world_size())
+                  : tensor;
+          if (!wo_weight.defined()) {
+            return;
+          }
+          correct_tensor_dtype(wo_weight, state_key);
+          at_weight_tensors_[kInFfnWoWeight] = wo_weight.to(device_);
+        }
+      };
+
+  for (const auto& [state_key, tensor] : state_dict) {
+    load_dense_fused_ffn_weights(state_key, tensor);
+  }
+
   std::vector> ordered_mapping(
       weight_mapping.begin(), weight_mapping.end());
   std::sort(
@@ -1040,11 +1151,22 @@ int64_t NpuOneRecBlockLayerImpl::init_layer() {
       is_decoder_ ? "onerec_decoder_block_layer" : "onerec_encoder_block_layer";
   model_name_ = "onerec";
   CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_));
+  if (kEnableOneRecAclnnAttentionLinear &&
+      kOneRecAclnnAttentionLinearMinTokens > 0) {
+    CHECK_OPERATION_STATUS_RETURN(
+        init_node(prefill_node_atb_, prefill_param_atb_));
+  }
   if (is_decoder_) {
     if (FLAGS_enable_rec_prefill_only) {
-      LOG(INFO) << "OneRec BlockLayer init_layer skip decode node because "
-                   "enable_rec_prefill_only is enabled"
-                << ", layer_id=" << layer_id_;
+      CHECK_OPERATION_STATUS_RETURN(
+          init_node(decoder_prefill_only_decode_node_,
+                    decoder_prefill_only_decode_param_));
+      if (kEnableOneRecAclnnAttentionLinear &&
+          kOneRecAclnnAttentionLinearMinTokens > 0) {
+        CHECK_OPERATION_STATUS_RETURN(
+            init_node(decoder_prefill_only_decode_node_atb_,
+                      decoder_prefill_only_decode_param_atb_));
+      }
       LOG(INFO) << "OneRec BlockLayer init_layer success"
                 << ", layer_role=" << (is_decoder_ ? "decoder" : "encoder")
                 << ", layer_id=" << layer_id_ << ", status=" << atb::NO_ERROR;
@@ -1119,20 +1241,77 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward(
   const bool is_prefill =
       onerec_params->rec_stage == OneRecModelInputParams::RecStage::PREFILL;
+  const bool is_first_prefill = onerec_params->is_first_prefill;
+  const int64_t ntokens = x.dim() >= 1 ? x.size(0) : 1;
+  const bool use_atb_small_tokens =
+      kEnableOneRecAclnnAttentionLinear &&
+      kOneRecAclnnAttentionLinearMinTokens > 0 && ntokens > 0 &&
+      ntokens < kOneRecAclnnAttentionLinearMinTokens;

   atb::Status st;
   if (is_prefill) {
     if (is_decoder_) {
-      if (prefill_param_.use_moe) {
+      if (FLAGS_enable_rec_prefill_only) {
+        if (is_first_prefill && encoder_output != nullptr) {
+          const int64_t bs = encoder_output->size(0);
+          const int64_t seq_len = encoder_output->size(1);
+          const int64_t kv_hidden_size =
+              prefill_param_.numKeyValueHeadsPerRank *
+              prefill_param_.hiddenSizePerAttentionHead;
+          auto options = torch::TensorOptions()
+                             .dtype(encoder_output->dtype())
+                             .device(encoder_output->device());
+          cross_k_cache_ = torch::empty({bs, seq_len, kv_hidden_size}, options);
+          cross_v_cache_ = torch::empty({bs, seq_len, kv_hidden_size}, options);
+        }
+
+        atb_speed::Model::Node& target_node =
+            is_first_prefill
+                ? ((use_atb_small_tokens &&
+                    prefill_node_atb_.operation != nullptr)
+                       ? prefill_node_atb_
+                       : prefill_node_)
+                : ((use_atb_small_tokens &&
+                    decoder_prefill_only_decode_node_atb_.operation != nullptr)
+                       ? decoder_prefill_only_decode_node_atb_
+                       : decoder_prefill_only_decode_node_);
+        if (prefill_param_.use_moe) {
+          build_decoder_moe_node_variant_pack(
+              target_node,
+              x,
+              attn_mask,
+              kv_cache,
+              input_params,
+              true,
+              is_first_prefill,
+              is_first_prefill ? encoder_output : nullptr,
+              node_id,
+              expert_array);
+        } else {
+          build_decoder_node_variant_pack(
+              target_node,
+              x,
+              attn_mask,
+              kv_cache,
+              input_params,
+              true,
+              is_first_prefill,
+              is_first_prefill ? encoder_output : nullptr,
+              node_id);
+        }
+        st = execute_node(target_node, node_id, event, event_flag);
+      } else if (prefill_param_.use_moe) {
         build_decoder_moe_node_variant_pack(prefill_node_,
                                             x,
                                             attn_mask,
                                             kv_cache,
                                             input_params,
                                             true,
+                                            true,
                                             encoder_output,
                                             node_id,
                                             expert_array);
+        st = execute_node(prefill_node_, node_id, event, event_flag);
       } else {
         build_decoder_node_variant_pack(prefill_node_,
                                         x,
@@ -1140,10 +1319,11 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward(
                                         kv_cache,
                                         input_params,
                                         true,
+                                        true,
                                         encoder_output,
                                         node_id);
+        st = execute_node(prefill_node_, node_id, event, event_flag);
       }
-      st = execute_node(prefill_node_, node_id, event, event_flag);
       LOG_IF(FATAL, st != 0)
           << model_name_ << " execute prefill layer fail, error code: " << st;
     } else {
@@ -1166,6 +1346,7 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward(
                                             kv_cache,
                                             input_params,
                                             false,
+                                            false,
                                             encoder_output,
                                             node_id,
                                             expert_array);
@@ -1176,6 +1357,7 @@ torch::Tensor NpuOneRecBlockLayerImpl::forward(
                                         kv_cache,
                                         input_params,
                                         false,
+                                        false,
                                         encoder_output,
                                         node_id);
     }
@@ -1247,10 +1429,10 @@ void NpuOneRecBlockLayerImpl::build_decoder_moe_node_variant_pack(
     KVCache& kv_cache,
     ModelInputParams& input_params,
     bool is_prefill,
+    bool is_first_prefill,
    torch::Tensor* encoder_output,
     int32_t layer_id,
     const torch::Tensor& expert_array) {
-  (void)kv_cache;
   (void)is_prefill;
   (void)layer_id;

@@ -1280,7 +1462,17 @@ void NpuOneRecBlockLayerImpl::build_decoder_moe_node_variant_pack(
                                      : placeholder_;

   int32_t tensor_idx = setup_common_decoder_tensors(
-      node, x, attn_mask, input_params, encoder_output, moe_tensor_start + 4);
+      node,
+      x,
+      attn_mask,
+      kv_cache,
+      input_params,
+      (FLAGS_enable_rec_prefill_only && is_prefill && !is_first_prefill)
+          ? decoder_prefill_only_decode_param_
+          : (is_prefill ? prefill_param_ : decode_param_),
+      is_first_prefill,
+      encoder_output,
+      moe_tensor_start + 4);

   while (tensor_idx < static_cast(node.variantPack.inTensors.size())) {
     node.variantPack.inTensors.at(tensor_idx) = placeholder_;
@@ -1294,7 +1486,10 @@ int32_t NpuOneRecBlockLayerImpl::setup_common_decoder_tensors(
     atb_speed::Model::Node& node,
     torch::Tensor& x,
     at::Tensor& attn_mask,
+    KVCache& kv_cache,
     ModelInputParams& input_params,
+    const atb_speed::onerec::BlockLayerParam& param,
+    bool is_first_prefill,
     torch::Tensor* encoder_output,
     int32_t start_tensor_idx) {
   internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x);
@@ -1304,27 +1499,44 @@ int32_t NpuOneRecBlockLayerImpl::setup_common_decoder_tensors(
   node.variantPack.inTensors.at(idx++) =
       atb_speed::Utils::AtTensor2Tensor(attn_mask);

-  // Token offset and layer id placeholders.
-  // ATB expects hostData to be valid for these scalar inputs. Keep them as
-  // placeholders but always provide hostData to avoid undefined reads.
-  node.variantPack.inTensors.at(idx) = placeholder_;
-  node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data();
-  node.variantPack.inTensors.at(idx) = placeholder_;
-  node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data();
+  torch::Tensor k_cache = kv_cache.get_k_cache();
+  torch::Tensor v_cache = kv_cache.get_v_cache();
+  node.variantPack.inTensors.at(idx++) =
+      k_cache.defined() ? atb_speed::Utils::AtTensor2Tensor(k_cache)
+                        : placeholder_;
+  node.variantPack.inTensors.at(idx++) =
+      v_cache.defined() ? atb_speed::Utils::AtTensor2Tensor(v_cache)
+                        : placeholder_;

-  CHECK(input_params.kv_seq_lens.defined()) << "kv_seq_lens is required.";
-  node.variantPack.inTensors.at(idx) =
-      atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens);
-  node.variantPack.inTensors.at(idx).hostData =
-      input_params.kv_seq_lens_vec.data();
+  if (input_params.kv_seq_lens.defined()) {
+    node.variantPack.inTensors.at(idx) =
+        atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens);
+    node.variantPack.inTensors.at(idx).hostData =
+        input_params.kv_seq_lens_vec.data();
+  } else {
+    int32_t seq_len = std::max(static_cast(x.size(0)), 1);
+    seq_lens_vec_ = {seq_len};
+    fallback_kv_seq_lens_tensor_ = torch::tensor(
+        seq_lens_vec_,
+        torch::TensorOptions().dtype(torch::kInt32).device(device_));
+    node.variantPack.inTensors.at(idx) =
+        atb_speed::Utils::AtTensor2Tensor(fallback_kv_seq_lens_tensor_);
+    node.variantPack.inTensors.at(idx).hostData = seq_lens_vec_.data();
+  }
   idx++;

+  // Token offset and layer id placeholders.
+  // ATB expects hostData to be valid for these scalar inputs. Keep them as
+  // placeholders but always provide hostData to avoid undefined reads.
   node.variantPack.inTensors.at(idx) = placeholder_;
   node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data();
   node.variantPack.inTensors.at(idx) = placeholder_;
   node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data();

-  if (!FLAGS_enable_rec_prefill_only && input_params.block_tables.defined()) {
+  // Align with xllm_rec T5 prefill-only path: self-attn block tables are not
+  // consumed during decoder prefill-only execution, so do not forward the
+  // runtime empty [bs, 0] tensor to ATB.
+  if (!param.enableOneRecPrefillOnly && input_params.block_tables.defined()) {
     node.variantPack.inTensors.at(idx) =
         atb_speed::Utils::AtTensor2Tensor(input_params.block_tables);
   } else {
@@ -1333,7 +1545,7 @@ int32_t NpuOneRecBlockLayerImpl::setup_common_decoder_tensors(
   }
   idx++;
-  if (!FLAGS_enable_rec_prefill_only &&
+  if (!param.enableOneRecPrefillOnly &&
       input_params.new_cache_slots.defined()) {
     node.variantPack.inTensors.at(idx) =
         atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots);
   } else {
@@ -1343,7 +1555,7 @@ int32_t NpuOneRecBlockLayerImpl::setup_common_decoder_tensors(
   }
   idx++;

-  if (encoder_output != nullptr) {
+  if (is_first_prefill && encoder_output != nullptr) {
     encoder_output_contiguous_ = encoder_output->is_contiguous()
                                      ? *encoder_output
                                      : encoder_output->contiguous();
@@ -1354,21 +1566,73 @@ int32_t NpuOneRecBlockLayerImpl::setup_common_decoder_tensors(
   }
   idx++;

-  for (int32_t i = 0; i < 3; i++) {
-    node.variantPack.inTensors.at(idx) = placeholder_;
-    node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data();
+  const auto* onerec_params = input_params.onerec_params();
+  const bool minimize_cross_attn_inputs =
+      param.enableOneRecPrefillOnly && !param.enableSplitFuse && !param.isFA;
+  if (!minimize_cross_attn_inputs) {
+    if (onerec_params != nullptr &&
+        onerec_params->cross_attn_kv_cu_seq_lens.defined()) {
+      node.variantPack.inTensors.at(idx) = atb_speed::Utils::AtTensor2Tensor(
+          onerec_params->cross_attn_kv_cu_seq_lens);
+      node.variantPack.inTensors.at(idx).hostData = const_cast(
+          onerec_params->cross_attn_kv_cu_seq_lens_vec.data());
+    } else {
+      node.variantPack.inTensors.at(idx) = placeholder_;
+      node.variantPack.inTensors.at(idx).hostData = placeholder_vec_.data();
+    }
+    idx++;
+
+    if (onerec_params != nullptr &&
+        onerec_params->cross_attn_block_tables.defined()) {
+      node.variantPack.inTensors.at(idx) = atb_speed::Utils::AtTensor2Tensor(
+          onerec_params->cross_attn_block_tables);
+    } else {
+      node.variantPack.inTensors.at(idx) = placeholder_;
+      node.variantPack.inTensors.at(idx).hostData = placeholder_vec_.data();
+    }
+    idx++;
+
+    if (is_first_prefill && onerec_params != nullptr &&
+        onerec_params->cross_attn_new_cache_slots.defined()) {
+      node.variantPack.inTensors.at(idx) = atb_speed::Utils::AtTensor2Tensor(
+          onerec_params->cross_attn_new_cache_slots);
+    } else {
+      node.variantPack.inTensors.at(idx) = placeholder_;
+      node.variantPack.inTensors.at(idx).hostData = placeholder_vec_.data();
+    }
+    idx++;
   }

-  const auto* onerec_params = input_params.onerec_params();
   if (onerec_params != nullptr &&
       onerec_params->encoder_seq_lens_tensor.defined()) {
     node.variantPack.inTensors.at(idx) = atb_speed::Utils::AtTensor2Tensor(
         onerec_params->encoder_seq_lens_tensor);
-    node.variantPack.inTensors.at(idx++).hostData =
+    node.variantPack.inTensors.at(idx).hostData =
         const_cast(onerec_params->encoder_seq_lens.data());
+  } else {
+    node.variantPack.inTensors.at(idx) = placeholder_;
+    node.variantPack.inTensors.at(idx).hostData = placeholder_vec_.data();
+  }
+  idx++;
+
+  if (param.enableOneRecPrefillOnly && cross_k_cache_.defined() &&
+      cross_v_cache_.defined()) {
+    if (is_first_prefill && node.variantPack.outTensors.size() >= 3) {
+      node.variantPack.outTensors.at(1) =
+          atb_speed::Utils::AtTensor2Tensor(cross_k_cache_);
+      node.variantPack.outTensors.at(2) =
+          atb_speed::Utils::AtTensor2Tensor(cross_v_cache_);
+    } else if (!is_first_prefill) {
+      node.variantPack.inTensors.at(idx++) =
+          atb_speed::Utils::AtTensor2Tensor(cross_k_cache_);
+      node.variantPack.inTensors.at(idx++) =
+          atb_speed::Utils::AtTensor2Tensor(cross_v_cache_);
+    }
   } else {
     node.variantPack.inTensors.at(idx) = placeholder_;
     node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data();
+    node.variantPack.inTensors.at(idx) = placeholder_;
+    node.variantPack.inTensors.at(idx++).hostData = placeholder_vec_.data();
   }

   node.variantPack.outTensors.at(0) = internal_tensors_;
@@ -1382,9 +1646,9 @@ void NpuOneRecBlockLayerImpl::build_decoder_node_variant_pack(
     KVCache& kv_cache,
     ModelInputParams& input_params,
     bool is_prefill,
+    bool is_first_prefill,
     torch::Tensor* encoder_output,
     int32_t layer_id) {
-  (void)kv_cache;
   (void)is_prefill;
   (void)layer_id;

@@ -1398,7 +1662,12 @@ void NpuOneRecBlockLayerImpl::build_decoder_node_variant_pack(
       node,
       x,
       attn_mask,
+      kv_cache,
       input_params,
+      (FLAGS_enable_rec_prefill_only && is_prefill && !is_first_prefill)
+          ? decoder_prefill_only_decode_param_
+          : (is_prefill ? prefill_param_ : decode_param_),
+      is_first_prefill,
       encoder_output,
       static_cast(kOneRecWeightCountPerLayer));
   while (tensor_idx < static_cast(node.variantPack.inTensors.size())) {
@@ -1417,6 +1686,18 @@ void NpuOneRecBlockLayerImpl::resize_experts_weights(
       std::vector(num_of_device_experts);
   experts_weights_["down_proj.weight"] =
       std::vector(num_of_device_experts);
+  experts_weights_["gate_proj.weight_offset"] =
+      std::vector(num_of_device_experts);
+  experts_weights_["up_proj.weight_offset"] =
+      std::vector(num_of_device_experts);
+  experts_weights_["down_proj.weight_offset"] =
+      std::vector(num_of_device_experts);
+  experts_weights_["gate_proj.weight_scale"] =
+      std::vector(num_of_device_experts);
+  experts_weights_["up_proj.weight_scale"] =
+      std::vector(num_of_device_experts);
+  experts_weights_["down_proj.weight_scale"] =
+      std::vector(num_of_device_experts);
 }

 void NpuOneRecBlockLayerImpl::process_expert_weights(
@@ -1426,6 +1707,57 @@ void NpuOneRecBlockLayerImpl::process_expert_weights(
   (void)state_dict;
   std::lock_guard lock(experts_mutex_);

+  auto unpack_fused_weight1 = [&](const torch::Tensor& fused_weight) {
+    if (num_experts_per_partition_ <= 0 || fused_weight.dim() != 2) {
+      LOG(WARNING) << "Invalid OneRec fused expert weight1 shape for " << name
+                   << ": " << fused_weight.sizes();
+      return;
+    }
+    torch::Tensor reshaped = fused_weight.contiguous().view(
+        {num_experts_per_partition_, fused_weight.size(0), -1});
+    CHECK_EQ(reshaped.size(2) % 2, 0)
+        << "OneRec fused expert weight1 last dim must be even for " << name
+        << ", got shape " << fused_weight.sizes();
+    std::vector gate_up_chunks =
+        reshaped.chunk(/*chunks=*/2, /*dim=*/-1);
+    for (int32_t i = 0; i < num_experts_per_partition_; ++i) {
+      experts_weights_["gate_proj.weight"][i] =
+          gate_up_chunks[0][i].transpose(0, 1).contiguous();
+      experts_weights_["up_proj.weight"][i] =
+          gate_up_chunks[1][i].transpose(0, 1).contiguous();
+    }
+    LOG(INFO) << "Unpacked OneRec fused routed expert weight1 into "
+              << num_experts_per_partition_
+              << " gate/up experts, source shape: [" << fused_weight.sizes()
+              << "]";
+  };
+
+  auto unpack_fused_weight2 = [&](const torch::Tensor& fused_weight) {
+    if (num_experts_per_partition_ <= 0 || fused_weight.dim() != 2) {
+      LOG(WARNING) << "Invalid OneRec fused expert weight2 shape for " << name
+                   << ": " << fused_weight.sizes();
+      return;
+    }
+    torch::Tensor reshaped = fused_weight.contiguous().view(
+        {num_experts_per_partition_, -1, fused_weight.size(1)});
+    for (int32_t i = 0; i < num_experts_per_partition_; ++i) {
+      experts_weights_["down_proj.weight"][i] =
+          reshaped[i].transpose(0, 1).contiguous();
+    }
+    LOG(INFO) << "Unpacked OneRec fused routed expert weight2 into "
+              << num_experts_per_partition_ << " down experts, source shape: ["
+              << fused_weight.sizes() << "]";
+  };
+
+  if (absl::StrContains(name, ".ffn.experts.weight1")) {
+    unpack_fused_weight1(tensor);
+    return;
+  }
+  if (absl::StrContains(name, ".ffn.experts.weight2")) {
+    unpack_fused_weight2(tensor);
+    return;
+  }
+
   int32_t expert_id = extract_expert_index(name);
   if (expert_id < 0) {
     return;
@@ -1443,6 +1775,24 @@ void NpuOneRecBlockLayerImpl::process_expert_weights(
   } else if (weight_suffix == "down_proj.weight" ||
              weight_suffix == "w2.weight") {
     suffix = "down_proj.weight";
+  } else if (weight_suffix == "gate_proj.weight_offset" ||
+             weight_suffix == "w1.weight_offset") {
+    suffix = "gate_proj.weight_offset";
+  } else if (weight_suffix == "up_proj.weight_offset" ||
+             weight_suffix == "w3.weight_offset") {
+    suffix = "up_proj.weight_offset";
+  } else if (weight_suffix == "down_proj.weight_offset" ||
+             weight_suffix == "w2.weight_offset") {
+    suffix = "down_proj.weight_offset";
+  } else if (weight_suffix == "gate_proj.weight_scale" ||
+             weight_suffix == "w1.weight_scale") {
+    suffix = "gate_proj.weight_scale";
+  } else if (weight_suffix == "up_proj.weight_scale" ||
+             weight_suffix == "w3.weight_scale") {
+    suffix = "up_proj.weight_scale";
+  } else if (weight_suffix == "down_proj.weight_scale" ||
+             weight_suffix == "w2.weight_scale") {
+    suffix = "down_proj.weight_scale";
   } else {
     return;
   }
@@ -1581,58 +1931,135 @@ void NpuOneRecBlockLayerImpl::merge_experts_weights() {
       at_npu::native::npu_format_cast(merged_gate_up, /*format=*/2)
           .contiguous();

+  if (quantize_type_ == "w8a8_dynamic") {
+    if (experts_weights_.count("gate_proj.weight_offset") > 0 &&
+        experts_weights_.count("up_proj.weight_offset") > 0) {
+      std::vector gate_offset_1d;
+      std::vector up_offset_1d;
+      gate_offset_1d.reserve(
+          experts_weights_["gate_proj.weight_offset"].size());
+      up_offset_1d.reserve(experts_weights_["up_proj.weight_offset"].size());
+      for (const auto& tensor : experts_weights_["gate_proj.weight_offset"]) {
+        if (tensor.defined()) {
+          gate_offset_1d.emplace_back(tensor);
+        }
+      }
+      for (const auto& tensor : experts_weights_["up_proj.weight_offset"]) {
+        if (tensor.defined()) {
+          up_offset_1d.emplace_back(tensor);
+        }
+      }
+      if (!gate_offset_1d.empty() &&
+          gate_offset_1d.size() == up_offset_1d.size()) {
+        at_weight_tensors_[kInMlpGateUpOffsetExpert] =
+            merge_experts_weights(gate_offset_1d,
+                                  up_offset_1d,
+                                  /*transpose=*/false);
+      }
+    }
+    if (experts_weights_.count("gate_proj.weight_scale") > 0 &&
+        experts_weights_.count("up_proj.weight_scale") > 0) {
+      std::vector gate_scale_1d;
+      std::vector up_scale_1d;
+      gate_scale_1d.reserve(experts_weights_["gate_proj.weight_scale"].size());
+      up_scale_1d.reserve(experts_weights_["up_proj.weight_scale"].size());
+      for (const auto& tensor : experts_weights_["gate_proj.weight_scale"]) {
+        if (tensor.defined()) {
+          gate_scale_1d.emplace_back(tensor);
+        }
+      }
+      for (const auto& tensor : experts_weights_["up_proj.weight_scale"]) {
+        if (tensor.defined()) {
+          up_scale_1d.emplace_back(tensor);
+        }
+      }
+      if (!gate_scale_1d.empty() &&
+          gate_scale_1d.size() == up_scale_1d.size()) {
+        at_weight_tensors_[kInMlpGateUpScaleExpert] =
+            merge_experts_weights(gate_scale_1d,
+                                  up_scale_1d,
+                                  /*transpose=*/false);
+      }
+    }
+  }
+
   auto merged_down =
       merge_experts_weights(experts_weights_["down_proj.weight"],
                             /*transpose=*/false);
   CHECK(merged_down.defined()) << "OneRec MoE down experts merge failed.";
   at_weight_tensors_[kInMoeExpertW2Weight] =
       at_npu::native::npu_format_cast(merged_down, /*format=*/2).contiguous();
+
+  if (quantize_type_ == "w8a8_dynamic") {
+    if (experts_weights_.count("down_proj.weight_offset") > 0) {
+      std::vector down_offset_1d;
+      down_offset_1d.reserve(
+          experts_weights_["down_proj.weight_offset"].size());
+      for (const auto& tensor : experts_weights_["down_proj.weight_offset"]) {
+        if (tensor.defined()) {
+          down_offset_1d.emplace_back(tensor);
+        }
+      }
+      if (!down_offset_1d.empty()) {
+        at_weight_tensors_[kInMlpDownOffsetExpert] =
+            merge_experts_weights(down_offset_1d, /*transpose=*/false);
+      }
+    }
+    if (experts_weights_.count("down_proj.weight_scale") > 0) {
+      std::vector down_scale_1d;
+      down_scale_1d.reserve(experts_weights_["down_proj.weight_scale"].size());
+      for (const auto& tensor : experts_weights_["down_proj.weight_scale"]) {
+        if (tensor.defined()) {
+          down_scale_1d.emplace_back(tensor);
+        }
+      }
+      if (!down_scale_1d.empty()) {
+        at_weight_tensors_[kInMlpDownScaleExpert] =
+            merge_experts_weights(down_scale_1d, /*transpose=*/false);
+      }
+    }
+  }
 }

 void NpuOneRecBlockLayerImpl::merge_shared_experts_weights() {
-  shared_expert_gate_weights_.clear();
-  shared_expert_up_weights_.clear();
-  shared_expert_down_weights_.clear();
+  auto get_shared_weight = [this](const char* key) -> torch::Tensor {
+    auto it = shared_expert_weights_map_.find(key);
+    return it == shared_expert_weights_map_.end() ? torch::Tensor()
+                                                  : it->second;
+  };

-  if (const auto it = shared_expert_weights_map_.find("gate_proj.weight");
-      it != shared_expert_weights_map_.end()) {
-    shared_expert_gate_weights_.push_back(it->second);
-  }
-  if (const auto it = shared_expert_weights_map_.find("up_proj.weight");
-      it != shared_expert_weights_map_.end()) {
-    shared_expert_up_weights_.push_back(it->second);
-  }
-  if (const auto it = shared_expert_weights_map_.find("down_proj.weight");
-      it != shared_expert_weights_map_.end()) {
-    shared_expert_down_weights_.push_back(it->second);
-  }
+  torch::Tensor gate_weight = get_shared_weight("gate_proj.weight");
+  torch::Tensor up_weight = get_shared_weight("up_proj.weight");
+  torch::Tensor down_weight = get_shared_weight("down_proj.weight");

-  if (shared_expert_gate_weights_.empty() &&
-      shared_expert_up_weights_.empty() &&
-      shared_expert_down_weights_.empty()) {
+  if (!gate_weight.defined() && !up_weight.defined() &&
+      !down_weight.defined()) {
     return;
   }

-  if (!shared_expert_gate_weights_.empty() &&
-      !shared_expert_up_weights_.empty()) {
-    auto merged_gate_up = merge_experts_weights(shared_expert_gate_weights_,
-                                                shared_expert_up_weights_,
-                                                /*transpose=*/false);
-    CHECK(merged_gate_up.defined())
-        << "OneRec shared gate/up experts merge failed at layer " << layer_id_;
-    at_weight_tensors_[kInMlpGateUpWeightSharedExpert] = merged_gate_up;
-  } else if (!shared_expert_gate_weights_.empty()) {
+  auto prepare_shared_weight_2d = [&](const torch::Tensor& weight,
+                                      const char* tag) {
+    CHECK(weight.defined()) << "No OneRec shared expert weights for " << tag;
+    torch::Tensor merged = weight.to(device_).contiguous();
+    CHECK_EQ(merged.dim(), 2) << "OneRec shared expert " << tag
+                              << " must stay 2D, got " << merged.sizes();
+    return merged;
+  };
+
+  if (gate_weight.defined() && up_weight.defined()) {
+    torch::Tensor merged_gate = prepare_shared_weight_2d(gate_weight, "gate");
+    torch::Tensor merged_up = prepare_shared_weight_2d(up_weight, "up");
+    at_weight_tensors_[kInMlpGateUpWeightSharedExpert] =
+        torch::cat({merged_gate, merged_up}, /*dim=*/0).contiguous();
+  } else if (gate_weight.defined()) {
     at_weight_tensors_[kInMlpGateUpWeightSharedExpert] =
-        merge_experts_weights(shared_expert_gate_weights_, false);
+        prepare_shared_weight_2d(gate_weight, "gate_only");
   }
-  if (!shared_expert_down_weights_.empty()) {
+  if (down_weight.defined()) {
     at_weight_tensors_[kInMlpDownWeightSharedExpert] =
-        merge_experts_weights(shared_expert_down_weights_, false);
+        prepare_shared_weight_2d(down_weight, "down");
   }
-  shared_expert_gate_weights_.clear();
-  shared_expert_up_weights_.clear();
-  shared_expert_down_weights_.clear();
   shared_expert_weights_map_.clear();
 }
diff --git a/xllm/core/layers/npu/npu_onerec_block_layer_impl.h b/xllm/core/layers/npu/npu_onerec_block_layer_impl.h
index ab21afe62..908f249ab 100644
--- a/xllm/core/layers/npu/npu_onerec_block_layer_impl.h
+++ b/xllm/core/layers/npu/npu_onerec_block_layer_impl.h
@@ -84,6 +84,7 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer {
       KVCache& kv_cache,
       ModelInputParams& input_params,
       bool is_prefill,
+      bool is_first_prefill,
       torch::Tensor* encoder_output = nullptr,
       int32_t layer_id = 0);

@@ -94,6 +95,7 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer {
       KVCache& kv_cache,
       ModelInputParams& input_params,
       bool is_prefill,
+      bool is_first_prefill,
       torch::Tensor* encoder_output = nullptr,
       int32_t layer_id = 0,
       const torch::Tensor& expert_array = torch::Tensor());
@@ -103,12 +105,16 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer {

   int64_t init_attn_mask();

-  int32_t setup_common_decoder_tensors(atb_speed::Model::Node& node,
-                                       torch::Tensor& x,
-                                       at::Tensor& attn_mask,
-                                       ModelInputParams& input_params,
-                                       torch::Tensor* encoder_output = nullptr,
-                                       int32_t start_tensor_idx = 0);
+  int32_t setup_common_decoder_tensors(
+      atb_speed::Model::Node& node,
+      torch::Tensor& x,
+      at::Tensor& attn_mask,
+      KVCache& kv_cache,
+      ModelInputParams& input_params,
+      const atb_speed::onerec::BlockLayerParam& param,
+      bool is_first_prefill,
+      torch::Tensor* encoder_output = nullptr,
+      int32_t start_tensor_idx = 0);

   void resize_experts_weights(int32_t num_of_device_experts);
   void process_expert_weights(const StateDict& state_dict,
@@ -129,10 +135,16 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer {
   std::string extract_endswith(const std::string& input);

   atb_speed::Model::Node prefill_node_;
+  atb_speed::Model::Node prefill_node_atb_;
   atb_speed::Model::Node decode_node_;
+  atb_speed::Model::Node decoder_prefill_only_decode_node_;
+  atb_speed::Model::Node decoder_prefill_only_decode_node_atb_;
   std::string model_name_;
   atb_speed::onerec::BlockLayerParam prefill_param_;
+  atb_speed::onerec::BlockLayerParam prefill_param_atb_;
   atb_speed::onerec::BlockLayerParam decode_param_;
+  atb_speed::onerec::BlockLayerParam decoder_prefill_only_decode_param_;
+  atb_speed::onerec::BlockLayerParam decoder_prefill_only_decode_param_atb_;

   atb::Tensor internal_tensors_;
   atb::Tensor placeholder_;
@@ -140,6 +152,10 @@ class NpuOneRecBlockLayerImpl final : public BaseLayer {
   at::Tensor encoder_output_contiguous_;
   at::Tensor at_placeholder_;
   std::vector placeholder_vec_;
+  torch::Tensor cross_k_cache_;
+  torch::Tensor cross_v_cache_;
+  torch::Tensor fallback_kv_seq_lens_tensor_;
+  std::vector seq_lens_vec_;

   int32_t device_id_ = 0;
   bool is_decoder_ = false;
diff --git a/xllm/models/rec/npu/onerec.h b/xllm/models/rec/npu/onerec.h
index 96496c5ea..3452da210 100644
--- a/xllm/models/rec/npu/onerec.h
+++ b/xllm/models/rec/npu/onerec.h
@@ -270,8 +270,17 @@ class OneRecForConditionalGenerationImpl final
   void load_model(std::unique_ptr loader,
                   std::string prefix = "model.") override {
     for (const auto& state_dict : loader->get_state_dicts()) {
-      StateDict model_state_dict = state_dict->get_dict_with_prefix(prefix);
-      if (model_state_dict.size() == 0) {
+      StateDict prefixed_state_dict =
+          state_dict->get_dict_with_prefix("module.module3.t5_model.");
+      StateDict model_state_dict =
+          prefixed_state_dict.size() > 0
+              ? prefixed_state_dict
+              : state_dict->get_dict_with_prefix(prefix);
+      if (prefixed_state_dict.size() > 0) {
+        LOG(INFO) << "Detected temporary OneRec checkpoint prefix "
+                  << "`module.module3.t5_model.`; loading weights via the "
+                     "compatibility path.";
+      } else if (model_state_dict.size() == 0) {
         model_state_dict = *state_dict;
       }
       model_->load_state_dict(model_state_dict);
diff --git a/xllm/models/rec/npu/onerec_npu_impl.h b/xllm/models/rec/npu/onerec_npu_impl.h
index a04233d6e..dc3c153e9 100644
--- a/xllm/models/rec/npu/onerec_npu_impl.h
+++ b/xllm/models/rec/npu/onerec_npu_impl.h
@@ -251,8 +251,10 @@ class OneRecStackImpl : public torch::nn::Module {
     const bool is_decode_stage = is_decoder_ && !is_prefill;
     torch::Tensor effective_attn_mask;
     if (use_absolute_position_embedding_) {
+      const int64_t batch_size =
+          std::max(1, input_params.num_sequences);
       effective_attn_mask =
-          create_moe_attention_mask(query_length, h, is_decoder_);
+          create_moe_attention_mask(query_length, h, is_decoder_, batch_size);
     } else {
       effective_attn_mask = compute_position_bias_mask(
           query_length, key_length, h, is_decode_stage, input_params);
@@ -382,24 +384,22 @@
   torch::Tensor create_moe_attention_mask(int64_t seq_length,
                                           const torch::Tensor& h,
-                                          bool is_decoder) const {
+                                          bool is_decoder,
+                                          int64_t batch_size) const {
     if (!is_decoder) {
       return torch::ones({num_heads_, seq_length, seq_length}, h.options());
     }

-    const float mask_value = -9984.0f;
-    auto upper_tri_mask =
+    batch_size = std::max(1, batch_size);
+    torch::Tensor mask =
         torch::triu(torch::ones({seq_length, seq_length},
-                                torch::dtype(h.dtype()).device(h.device())),
-                    1);
-    auto expanded_mask = upper_tri_mask.unsqueeze(0).expand(
-        {num_heads_, seq_length, seq_length});
-    auto effective_attn_mask =
-        torch::zeros({num_heads_, seq_length, seq_length},
-                     torch::dtype(h.dtype()).device(h.device()));
-    effective_attn_mask.masked_fill_(expanded_mask.to(torch::kBool),
-                                     mask_value);
-    return effective_attn_mask;
+                                h.options().dtype(torch::kUInt8)),
+                    1)
+            .unsqueeze(0)
+            .unsqueeze(0)
+            .expand({batch_size, 1, seq_length, seq_length})
+            .contiguous();
+    return mask;
   }

   torch::Tensor compute_position_bias_mask(