diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 0d50157a5..8abb650af 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -657,9 +657,9 @@ DEFINE_int64(sp_size, 1, "Sequence parallelism size"); DEFINE_int64(cfg_size, 1, "Classifier-free guidiance parallelism size"); -DEFINE_int64(dit_sp_communication_overlap, - 1, - "Communication & Computation overlap for sequence parallel"); +DEFINE_bool(dit_sp_communication_overlap, + false, + "Communication & Computation overlap for sequence parallel"); // --- dit debug --- diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 4a55f5dd2..d9b316979 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -314,6 +314,8 @@ DECLARE_int64(cfg_size); DECLARE_bool(dit_debug_print); +DECLARE_bool(dit_sp_communication_overlap); + // --- multi-step decode config --- DECLARE_int32(max_decode_rounds); diff --git a/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h index 01a722eb6..34e373091 100644 --- a/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h +++ b/xllm/models/dit/npu/qwen_image_edit/pipeline_qwenimage_edit_plus.h @@ -70,6 +70,30 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { register_module("scheduler", scheduler_); register_module("transformer", transformer_); register_module("vae_image_processor", vae_image_processor_); + + use_layer3d_rope_ = context.get_model_context("transformer") + .get_model_args() + .use_layer3d_rope(); + std::vector axes_dims_rope = + context.get_model_context("transformer") + .get_model_args() + .axes_dims_rope(); + // Positional embedding + if (use_layer3d_rope_) { + pos_embed_3d_rope_ = register_module( + "pos_embed", + QwenEmbedLayer3DRope(context.get_model_context("transformer"), + /*theta=*/10000, + axes_dims_rope, + true)); + } else { + pos_embed_ = register_module( + "pos_embed", + QwenEmbedRope(context.get_model_context("transformer"), + /*theta=*/10000, + axes_dims_rope, + true)); + } } std::vector _extract_masked_hidden( @@ -463,7 +487,31 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { if (do_true_cfg && negative_prompt_embeds_mask.defined()) { negative_txt_seq_lens = negative_prompt_embeds_mask.sum(1); } + scheduler_->set_begin_index(0); + + int64_t origin_text_seq_len = prompt_embeds.size(1); + int64_t origin_neg_text_seq_len = negative_prompt_embeds.size(1); + std::tuple image_rotary_emb_pos; + std::tuple image_rotary_emb_neg; + if (use_layer3d_rope_) { + image_rotary_emb_pos = pos_embed_3d_rope_->forward( + main_shape, origin_text_seq_len, prompt_embeds.device()); + image_rotary_emb_neg = pos_embed_3d_rope_->forward( + main_shape, origin_neg_text_seq_len, prompt_embeds.device()); + } else { + image_rotary_emb_pos = + pos_embed_->forward(main_shape, + origin_text_seq_len, + prompt_embeds.device(), + /*max_txt_seq_len=*/std::nullopt); + image_rotary_emb_neg = + pos_embed_->forward(main_shape, + origin_neg_text_seq_len, + prompt_embeds.device(), + /*max_txt_seq_len=*/std::nullopt); + } + for (int64_t i = 0; i < timesteps.size(0); ++i) { auto t = timesteps[i]; current_timestep_ = t; @@ -488,6 +536,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { timestep_expanded / 1000.0, main_shape, txt_seq_lens, + image_rotary_emb_pos, /*use_cfg=*/false, /*step_index=*/i); noise_pred = noise_pred.slice(1, 0, final_latents.size(1)); @@ -502,6 +551,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { timestep_expanded / 1000.0, main_shape, negative_txt_seq_lens, + image_rotary_emb_neg, /*use_cfg=*/true, /*step_index=*/i); @@ -525,6 +575,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { timestep_expanded / 1000.0, main_shape, txt_seq_lens, + image_rotary_emb_pos, /*use_cfg=*/false, /*step_index=*/i); noise_pred = noise_pred.slice(1, 0, final_latents.size(1)); @@ -535,6 +586,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { timestep_expanded / 1000.0, main_shape, negative_txt_seq_lens, + image_rotary_emb_neg, /*use_cfg=*/true, /*step_index=*/i); @@ -610,6 +662,9 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl { torch::Tensor current_timestep_; string prompt_template_encode_; const ModelArgs& vae_model_args_; + bool use_layer3d_rope_; + QwenEmbedRope pos_embed_{nullptr}; + QwenEmbedLayer3DRope pos_embed_3d_rope_{nullptr}; }; REGISTER_MODEL_ARGS(Qwen2Tokenizer, [&] {}); diff --git a/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h b/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h index 2d4f5b4f0..6f6e208b2 100644 --- a/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h +++ b/xllm/models/dit/npu/qwen_image_edit/transformer_qwen_image.h @@ -1096,7 +1096,8 @@ class AttentionImpl : public torch::nn::Module { bool is_causal = false, ProcessGroup* sp_group = nullptr) : options_(context.get_tensor_options()), - heads_(heads), + q_heads_(heads), + dim_head_(dim_head), bias_(bias), out_bias_(out_bias), added_proj_bias_(added_proj_bias), @@ -1158,6 +1159,7 @@ class AttentionImpl : public torch::nn::Module { int64_t q_dim = out_dim.has_value() ? out_dim.value() : dim_head * heads; int64_t kv_dim = !kv_heads.has_value() ? q_dim : dim_head * kv_heads.value(); + kv_heads_ = kv_heads.has_value() ? kv_heads.value() : heads; cross_attention_dim = cross_attention_dim.has_value() ? cross_attention_dim.value() : query_dim; @@ -1167,7 +1169,7 @@ class AttentionImpl : public torch::nn::Module { xllm::dit::SpOptions q_sp_option; xllm::dit::SpOptions kv_sp_option; xllm::dit::LinearType linear_type = xllm::dit::LinearType::Default; - if (FLAGS_sp_size > 1) { + if (FLAGS_sp_size > 1 && !FLAGS_dit_sp_communication_overlap) { q_sp_option = xllm::dit::SpOptions(/*head_num=*/heads, /*head_dim=*/dim_head, /*hidden_size=*/q_dim, @@ -1249,7 +1251,7 @@ class AttentionImpl : public torch::nn::Module { } xllm::dit::SpOptions out_sp_option; - if (FLAGS_sp_size > 1) { + if (FLAGS_sp_size > 1 && !FLAGS_dit_sp_communication_overlap) { out_sp_option = xllm::dit::SpOptions(/*head_num=*/heads, /*head_dim=*/dim_head, /*hidden_size=*/q_dim, @@ -1259,17 +1261,16 @@ class AttentionImpl : public torch::nn::Module { // Output projections if (!pre_only) { - to_out_ = register_module("to_out", torch::nn::Sequential()); - auto to_out_linear = layer::AddMatmulWeightTransposed( q_dim, out_dim.value(), out_bias, options_); - to_out_->push_back(xllm::dit::DiTParallelLinear(std::move(to_out_linear), - /*module_name=*/"out", - linear_type, - out_sp_option)); - to_out_->push_back( - torch::nn::Dropout(torch::nn::DropoutOptions(dropout))); + to_out_ = + register_module("to_out", + xllm::dit::DiTParallelLinear(std::move(to_out_linear), + /*module_name=*/"out", + linear_type, + out_sp_option)); + to_out_dropout_ = torch::nn::Dropout(torch::nn::DropoutOptions(dropout)); } // Additional output for context @@ -1302,8 +1303,7 @@ class AttentionImpl : public torch::nn::Module { void load_state_dict(const StateDict& state_dict) { // to_out - to_out_[0]->as()->load_state_dict( - state_dict.get_dict_with_prefix("to_out.0.")); + to_out_->load_state_dict(state_dict.get_dict_with_prefix("to_out.0.")); // to_add_out to_add_out_->load_state_dict( state_dict.get_dict_with_prefix("to_add_out.")); @@ -1332,8 +1332,7 @@ class AttentionImpl : public torch::nn::Module { void verify_loaded_weights(const std::string& prefix) { // to_out - to_out_[0]->as()->verify_loaded_weights( - prefix + "to_out.0."); + to_out_->verify_loaded_weights(prefix + "to_out.0."); // to_add_out to_add_out_->verify_loaded_weights(prefix + "to_add_out."); // norm_q @@ -1355,7 +1354,9 @@ class AttentionImpl : public torch::nn::Module { } public: - int64_t heads_; + int64_t q_heads_; + int64_t kv_heads_; + int64_t dim_head_; bool bias_; bool out_bias_; bool added_proj_bias_; @@ -1367,9 +1368,8 @@ class AttentionImpl : public torch::nn::Module { xllm::dit::DiTParallelLinear to_q_{nullptr}, to_k_{nullptr}, to_v_{nullptr}; xllm::dit::DiTParallelLinear add_k_proj_{nullptr}, add_v_proj_{nullptr}, add_q_proj_{nullptr}; - torch::nn::Sequential to_out_{nullptr}; - xllm::dit::DiTParallelLinear to_add_out_{nullptr}; - + xllm::dit::DiTParallelLinear to_out_{nullptr}, to_add_out_{nullptr}; + torch::nn::Dropout to_out_dropout_{nullptr}; // Assuming you have RMSNorm implemented RMSNorm norm_q_{nullptr}, norm_k_{nullptr}, norm_added_q_{nullptr}, norm_added_k_{nullptr}; @@ -1377,10 +1377,10 @@ class AttentionImpl : public torch::nn::Module { TORCH_MODULE(Attention); // Implementation of attention forward -class QwenDoubleStreamAttnProcessor2_0Impl : public torch::nn::Module { +class QwenDoubleStreamAttnProcessorBase : public torch::nn::Module { public: - QwenDoubleStreamAttnProcessor2_0Impl(Attention&& attn_module, - const ParallelArgs& parallel_args) + QwenDoubleStreamAttnProcessorBase(Attention&& attn_module, + const ParallelArgs& parallel_args) : parallel_args_(parallel_args) { attn_ = register_module("attn", std::move(attn_module)); } @@ -1390,7 +1390,266 @@ class QwenDoubleStreamAttnProcessor2_0Impl : public torch::nn::Module { const torch::Tensor& encoder_hidden_states, // Text stream const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(), const torch::Tensor& attention_mask = torch::Tensor(), - const std::tuple& image_rotary_emb = {}) { + const std::tuple& image_rotary_emb = + {}) = 0; + + virtual void load_state_dict(const StateDict& state_dict) { + attn_->load_state_dict(state_dict); + } + + virtual void verify_loaded_weights(const std::string& prefix) { + attn_->verify_loaded_weights(prefix); + } + + Attention attn_{nullptr}; + const ParallelArgs parallel_args_; +}; + +// Implementation of attention forward with communication & computation overlap +class QwenDoubleStreamAttnProcessorCMO2_0Impl final + : public QwenDoubleStreamAttnProcessorBase { + public: + QwenDoubleStreamAttnProcessorCMO2_0Impl(Attention&& attn_module, + const ParallelArgs& parallel_args) + : QwenDoubleStreamAttnProcessorBase(std::move(attn_module), + parallel_args) { + q_heads_ = attn_->q_heads_; + kv_heads_ = attn_->kv_heads_; + split_q_heads_ = q_heads_ / FLAGS_sp_size; + split_kv_heads_ = kv_heads_ / FLAGS_sp_size; + dim_head_ = attn_->dim_head_; + q_hidden_size_ = q_heads_ * dim_head_; + kv_hidden_size_ = kv_heads_ * dim_head_; + } + + std::tuple forward( + const torch::Tensor& hidden_states, // Image stream + const torch::Tensor& encoder_hidden_states, // Text stream + const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(), + const torch::Tensor& attention_mask = torch::Tensor(), + const std::tuple& image_rotary_emb = {}) + override { + // Compute QKV for image stream (sample projections) + // auto reshape_dims = std::vector{heads / FLAGS_sp_size, + // dim_head}; + + auto img_query = attn_->to_q_->forward(hidden_states); + auto query_handler = parallel_state::all_to_all_4D( + /*input=*/img_query.view({hidden_states.size(0), + -1, + q_heads_, + dim_head_}), // [B, S/sp_size, N, D] + /*scatter_dim=*/2, + /*gather_dim=*/1, + /*async_ops=*/true, + attn_->sp_group_); + + auto img_key = attn_->to_k_->forward(hidden_states); + auto key_handler = parallel_state::all_to_all_4D( + /*input=*/img_key.view({hidden_states.size(0), + -1, + kv_heads_, + dim_head_}), // [B, S/sp_size, N, D] + /*scatter_dim=*/2, + /*gather_dim=*/1, + /*async_ops=*/true, + attn_->sp_group_); + + auto img_value = attn_->to_v_->forward(hidden_states); + auto value_handler = parallel_state::all_to_all_4D( + /*input=*/img_value.view({hidden_states.size(0), + -1, + kv_heads_, + dim_head_}), // [B, S/sp_size, N, D] + /*scatter_dim=*/2, + /*gather_dim=*/1, + /*async_ops=*/true, + attn_->sp_group_); + + // Compute QKV for text stream (context projections) + auto txt_query = attn_->add_q_proj_->forward(encoder_hidden_states); + auto query_handler_txt = parallel_state::all_to_all_4D( + /*input=*/txt_query.view({hidden_states.size(0), + -1, + q_heads_, + dim_head_}), // [B, S/sp_size, N, D] + /*scatter_dim=*/2, + /*gather_dim=*/1, + /*async_ops=*/true, + attn_->sp_group_); + + auto txt_key = attn_->add_k_proj_->forward(encoder_hidden_states); + auto key_handler_txt = parallel_state::all_to_all_4D( + /*input=*/txt_key.view({hidden_states.size(0), + -1, + kv_heads_, + dim_head_}), // [B, S/sp_size, N, D] + /*scatter_dim=*/2, + /*gather_dim=*/1, + /*async_ops=*/true, + attn_->sp_group_); + + auto txt_value = attn_->add_v_proj_->forward(encoder_hidden_states); + auto value_handler_txt = parallel_state::all_to_all_4D( + /*input=*/txt_value.view({hidden_states.size(0), + -1, + kv_heads_, + dim_head_}), // [B, S/sp_size, N, D] + /*scatter_dim=*/2, + /*gather_dim=*/1, + /*async_ops=*/true, + attn_->sp_group_); + + img_query = query_handler(); // [B, S, N/sp_size, D] + img_key = key_handler(); // [B, S, N/sp_size, D] + + // Apply QK normalization + if (attn_->norm_q_) { + img_query = attn_->norm_q_->forward(img_query); + } + if (attn_->norm_k_) { + img_key = attn_->norm_k_->forward(img_key); + } + + txt_query = query_handler_txt(); // [B, S, N/sp_size, D] + txt_key = key_handler_txt(); // [B, S, N/sp_size, D] + + if (attn_->norm_added_q_) { + txt_query = attn_->norm_added_q_->forward(txt_query); + } + if (attn_->norm_added_k_) { + txt_key = attn_->norm_added_k_->forward(txt_key); + } + + // Apply RoPE if provided + auto img_freqs = std::get<0>(image_rotary_emb); + auto txt_freqs = std::get<1>(image_rotary_emb); + + img_value = value_handler(); // [B, S, N/sp_size, D] + txt_value = value_handler_txt(); // [B, S, N/sp_size, D] + + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + txt_query, /*tensor_name=*/"encoder_hidden_states", /*dim=*/1); + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + txt_key, /*tensor_name=*/"encoder_hidden_states", /*dim=*/1); + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + txt_value, /*tensor_name=*/"encoder_hidden_states", /*dim=*/1); + + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + img_query, /*tensor_name=*/"hidden_states", /*dim=*/1); + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + img_key, /*tensor_name=*/"hidden_states", /*dim=*/1); + xllm::dit::SequenceParallelPadManager::getInstance().unpad_tensor( + img_value, /*tensor_name=*/"hidden_states", /*dim=*/1); + + img_query = apply_rotary_emb_qwen(img_query, img_freqs, false); + img_key = apply_rotary_emb_qwen(img_key, img_freqs, false); + txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, false); + txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, false); + + // Concatenate for joint attention - Order: [text, image] + auto joint_query = torch::cat({txt_query, img_query}, 1); + auto joint_key = torch::cat({txt_key, img_key}, 1); + auto joint_value = torch::cat({txt_value, img_value}, 1); + + auto results = at_npu::native::custom_ops::npu_fusion_attention( + joint_query, + joint_key, + joint_value, + q_heads_ / FLAGS_sp_size, + /*input_layout=*/"BSND", + /*pse=*/torch::nullopt, + /*padding_mask=*/torch::nullopt, + /*atten_mask*/ torch::nullopt, + /*scale=*/pow(joint_query.size(3), -0.5), + /*keep_prob=*/1.0, + /*pre_tokens=*/65535, + /*next_tokens=*/65535); + + auto joint_hidden_states = std::get<0>(results); + // Reshape back + joint_hidden_states = joint_hidden_states.flatten(2, 3); + joint_hidden_states = joint_hidden_states.to(joint_query.dtype()); + + int64_t seq_txt = txt_query.size(1); + int64_t seq_img = img_query.size(1); + // Split attention outputs back + auto chunks = torch::split(joint_hidden_states, {seq_txt, seq_img}, 1); + auto txt_attn_output = chunks[0]; + auto img_attn_output = chunks[1]; + + txt_attn_output = + xllm::dit::SequenceParallelPadManager::getInstance().pad_tensor( + txt_attn_output, + /*tensor_name=*/"encoder_hidden_states", + /*dim=*/1); + + img_attn_output = + xllm::dit::SequenceParallelPadManager::getInstance().pad_tensor( + img_attn_output, /*tensor_name=*/"hidden_states", /*dim=*/1); + + // Apply output projections + auto img_out_handler = parallel_state::all_to_all_4D( + /*input=*/img_attn_output.view({hidden_states.size(0), + -1, + split_q_heads_, + dim_head_}), // [B, S, N/sp_size, D] + /*scatter_dim=*/1, + /*gather_dim=*/2, + /*async_ops=*/true, + attn_->sp_group_); + + auto txt_out_handler = parallel_state::all_to_all_4D( + /*input=*/txt_attn_output.view({hidden_states.size(0), + -1, + split_q_heads_, + dim_head_}), // [B, S, N/sp_size, D] + /*scatter_dim=*/1, + /*gather_dim=*/2, + /*async_ops=*/true, + attn_->sp_group_); + + img_attn_output = img_out_handler(); + img_attn_output = + img_attn_output.view({hidden_states.size(0), -1, q_hidden_size_}); + img_attn_output = attn_->to_out_->forward(img_attn_output); + img_attn_output = attn_->to_out_dropout_->forward(img_attn_output); + + txt_attn_output = txt_out_handler(); + txt_attn_output = + txt_attn_output.view({hidden_states.size(0), -1, q_hidden_size_}); + txt_attn_output = attn_->to_add_out_->forward(txt_attn_output); + + return std::make_tuple(img_attn_output, txt_attn_output); + } + + private: + int64_t q_heads_ = 0; + int64_t kv_heads_ = 0; + int64_t split_q_heads_ = 0; + int64_t split_kv_heads_ = 0; + int64_t dim_head_ = 0; + int64_t q_hidden_size_ = 0; + int64_t kv_hidden_size_ = 0; +}; +TORCH_MODULE(QwenDoubleStreamAttnProcessorCMO2_0); + +// Implementation of attention forward +class QwenDoubleStreamAttnProcessor2_0Impl final + : public QwenDoubleStreamAttnProcessorBase { + public: + QwenDoubleStreamAttnProcessor2_0Impl(Attention&& attn_module, + const ParallelArgs& parallel_args) + : QwenDoubleStreamAttnProcessorBase(std::move(attn_module), + parallel_args) {} + + std::tuple forward( + const torch::Tensor& hidden_states, // Image stream + const torch::Tensor& encoder_hidden_states, // Text stream + const torch::Tensor& encoder_hidden_states_mask = torch::Tensor(), + const torch::Tensor& attention_mask = torch::Tensor(), + const std::tuple& image_rotary_emb = {}) + override { // int64_t seq_txt = encoder_hidden_states.size(1); // int64_t seq_img = hidden_states.size(1); // Compute QKV for image stream (sample projections) @@ -1404,7 +1663,7 @@ class QwenDoubleStreamAttnProcessor2_0Impl : public torch::nn::Module { auto txt_value = attn_->add_v_proj_->forward(encoder_hidden_states); // Reshape for multi-head attention - int64_t heads = attn_->heads_; + int64_t heads = attn_->q_heads_; auto reshape_dims = std::vector{heads / FLAGS_sp_size, -1}; img_query = img_query.unflatten(-1, reshape_dims); @@ -1493,22 +1752,10 @@ class QwenDoubleStreamAttnProcessor2_0Impl : public torch::nn::Module { // Apply output projections img_attn_output = attn_->to_out_->forward(img_attn_output); - + img_attn_output = attn_->to_out_dropout_->forward(img_attn_output); txt_attn_output = attn_->to_add_out_->forward(txt_attn_output); return std::make_tuple(img_attn_output, txt_attn_output); } - - void load_state_dict(const StateDict& state_dict) { - attn_->load_state_dict(state_dict); - } - - void verify_loaded_weights(const std::string& prefix) { - attn_->verify_loaded_weights(prefix); - } - - protected: - Attention attn_{nullptr}; - const ParallelArgs parallel_args_; }; TORCH_MODULE(QwenDoubleStreamAttnProcessor2_0); @@ -1617,9 +1864,16 @@ class QwenImageTransformerBlockImpl : public torch::nn::Module { /*elementwise_affine=*/true, /*is_causal=*/false, /*sp_group=*/parallel_args_.dit_sp_group_); - attn_processor_ = register_module( - "attn_processor_", - QwenDoubleStreamAttnProcessor2_0(std::move(attn_), parallel_args_)); + if (!FLAGS_dit_sp_communication_overlap) { + attn_processor_ = register_module( + "attn_processor_", + QwenDoubleStreamAttnProcessor2_0(std::move(attn_), parallel_args_)); + } else { + attn_cmo_processor_ = + register_module("attn_processor_", + QwenDoubleStreamAttnProcessorCMO2_0(std::move(attn_), + parallel_args_)); + } // Image normalization 2 img_norm2_ = register_module("img_norm2", AdaLayerNorm(context, dim, eps)); @@ -1736,12 +1990,21 @@ class QwenImageTransformerBlockImpl : public torch::nn::Module { std::tie(txt_modulated, txt_gate1) = txt_norm1_->forward(encoder_hidden_states, txt_mod1); + std::tuple attn_output; // Use QwenAttnProcessor2_0 for joint attention computation - auto attn_output = attn_processor_->forward(img_modulated, // Image stream - txt_modulated, // Text stream - encoder_hidden_states_mask, - torch::Tensor(), // timestep - image_rotary_emb); + if (!FLAGS_dit_sp_communication_overlap) { + attn_output = attn_processor_->forward(img_modulated, // Image stream + txt_modulated, // Text stream + encoder_hidden_states_mask, + torch::Tensor(), // timestep + image_rotary_emb); + } else { + attn_output = attn_cmo_processor_->forward(img_modulated, // Image stream + txt_modulated, // Text stream + encoder_hidden_states_mask, + torch::Tensor(), // timestep + image_rotary_emb); + } // QwenAttnProcessor2_0 returns (img_output, txt_output) auto img_attn_output = std::get<0>(attn_output); @@ -1788,7 +2051,13 @@ class QwenImageTransformerBlockImpl : public torch::nn::Module { txt_mod_[1]->as()->load_state_dict( state_dict.get_dict_with_prefix("txt_mod.1.")); txt_mlp_->load_state_dict(state_dict.get_dict_with_prefix("txt_mlp.")); - attn_processor_->load_state_dict(state_dict.get_dict_with_prefix("attn.")); + if (!FLAGS_dit_sp_communication_overlap) { + attn_processor_->load_state_dict( + state_dict.get_dict_with_prefix("attn.")); + } else { + attn_cmo_processor_->load_state_dict( + state_dict.get_dict_with_prefix("attn.")); + } } void verify_loaded_weights(const std::string& prefix) { @@ -1798,7 +2067,11 @@ class QwenImageTransformerBlockImpl : public torch::nn::Module { txt_mod_[1]->as()->verify_loaded_weights( prefix + "txt_mod.1."); txt_mlp_->verify_loaded_weights(prefix + "txt_mlp."); - attn_processor_->verify_loaded_weights(prefix + "attn."); + if (!FLAGS_dit_sp_communication_overlap) { + attn_processor_->verify_loaded_weights(prefix + "attn."); + } else { + attn_cmo_processor_->verify_loaded_weights(prefix + "attn."); + } } private: @@ -1808,6 +2081,8 @@ class QwenImageTransformerBlockImpl : public torch::nn::Module { AdaLayerNorm img_norm2_{nullptr}; std::shared_ptr attn_{nullptr}; QwenDoubleStreamAttnProcessor2_0 attn_processor_{nullptr}; + QwenDoubleStreamAttnProcessorCMO2_0 attn_cmo_processor_{nullptr}; + // QwenDoubleStreamAttnProcessorBase attn_processor_{nullptr}; FeedForward img_mlp_{nullptr}; torch::nn::Sequential txt_mod_{nullptr}; @@ -1841,17 +2116,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module { out_channels = (out_channels > 0) ? out_channels : in_channels; auto inner_dim = num_attention_heads * attention_head_dim; - // Positional embedding - if (use_layer3d_rope_) { - pos_embed_3d_rope_ = register_module( - "pos_embed", - QwenEmbedLayer3DRope(context, /*theta=*/10000, axes_dims_rope, true)); - } else { - pos_embed_ = register_module( - "pos_embed", - QwenEmbedRope(context, /*theta=*/10000, axes_dims_rope, true)); - } - // Time-text embedding time_text_embed_ = register_module( "time_text_embed", @@ -1904,6 +2168,7 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module { torch::Tensor timestep = torch::Tensor(), std::vector> img_shapes = {}, torch::Tensor txt_seq_lens = torch::Tensor(), + const std::tuple& image_rotary_emb = {}, bool use_cfg = false, int64_t step_idx = 0, torch::Tensor addition_t_cond = torch::Tensor(), @@ -1938,8 +2203,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module { modulate_index = torch::Tensor(); } - auto origin_text_seq_len = encoder_hidden_states.size(1); - // padding mask for sequence parallel scene auto padded_encoder_hidden_states_mask = xllm::dit::SequenceParallelPadManager::getInstance().pad_tensor( @@ -1971,16 +2234,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module { padded_encoder_hidden_states_mask); auto temb = time_text_embed_->forward( new_timestep, new_hidden_states, addition_t_cond); - std::tuple image_rotary_emb; - if (use_layer3d_rope_) { - image_rotary_emb = pos_embed_3d_rope_->forward( - img_shapes, origin_text_seq_len, new_hidden_states.device()); - } else { - image_rotary_emb = pos_embed_->forward(img_shapes, - origin_text_seq_len, - new_hidden_states.device(), - /*max_txt_seq_len=*/std::nullopt); - } std::unordered_map block_attention_kwargs; if (new_encoder_hidden_states_mask.has_value() && @@ -2124,8 +2377,6 @@ class QwenImageTransformer2DModelImpl : public torch::nn::Module { private: torch::TensorOptions options_; - QwenEmbedRope pos_embed_{nullptr}; - QwenEmbedLayer3DRope pos_embed_3d_rope_{nullptr}; QwenTimestepProjEmbeddings time_text_embed_{nullptr}; RMSNorm txt_norm_{nullptr}; layer::AddMatmulWeightTransposed img_in_{nullptr};