Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions xllm/core/common/global_flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -657,9 +657,9 @@ DEFINE_int64(sp_size, 1, "Sequence parallelism size");

DEFINE_int64(cfg_size, 1, "Classifier-free guidiance parallelism size");

DEFINE_int64(dit_sp_communication_overlap,
1,
"Communication & Computation overlap for sequence parallel");
DEFINE_bool(dit_sp_communication_overlap,
false,
"Communication & Computation overlap for sequence parallel");

// --- dit debug ---

Expand Down
2 changes: 2 additions & 0 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ DECLARE_int64(cfg_size);

DECLARE_bool(dit_debug_print);

DECLARE_bool(dit_sp_communication_overlap);

// --- multi-step decode config ---

DECLARE_int32(max_decode_rounds);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,30 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
register_module("scheduler", scheduler_);
register_module("transformer", transformer_);
register_module("vae_image_processor", vae_image_processor_);

use_layer3d_rope_ = context.get_model_context("transformer")
.get_model_args()
.use_layer3d_rope();
std::vector<int64_t> axes_dims_rope =
context.get_model_context("transformer")
.get_model_args()
.axes_dims_rope();
// Positional embedding
if (use_layer3d_rope_) {
pos_embed_3d_rope_ = register_module(
"pos_embed",
QwenEmbedLayer3DRope(context.get_model_context("transformer"),
/*theta=*/10000,
axes_dims_rope,
true));
} else {
pos_embed_ = register_module(
"pos_embed",
QwenEmbedRope(context.get_model_context("transformer"),
/*theta=*/10000,
axes_dims_rope,
true));
}
}

std::vector<torch::Tensor> _extract_masked_hidden(
Expand Down Expand Up @@ -463,7 +487,31 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
if (do_true_cfg && negative_prompt_embeds_mask.defined()) {
negative_txt_seq_lens = negative_prompt_embeds_mask.sum(1);
}

scheduler_->set_begin_index(0);

int64_t origin_text_seq_len = prompt_embeds.size(1);
int64_t origin_neg_text_seq_len = negative_prompt_embeds.size(1);
std::tuple<torch::Tensor, torch::Tensor> image_rotary_emb_pos;
std::tuple<torch::Tensor, torch::Tensor> image_rotary_emb_neg;
if (use_layer3d_rope_) {
image_rotary_emb_pos = pos_embed_3d_rope_->forward(
main_shape, origin_text_seq_len, prompt_embeds.device());
image_rotary_emb_neg = pos_embed_3d_rope_->forward(
main_shape, origin_neg_text_seq_len, prompt_embeds.device());
} else {
image_rotary_emb_pos =
pos_embed_->forward(main_shape,
origin_text_seq_len,
prompt_embeds.device(),
/*max_txt_seq_len=*/std::nullopt);
image_rotary_emb_neg =
pos_embed_->forward(main_shape,
origin_neg_text_seq_len,
prompt_embeds.device(),
/*max_txt_seq_len=*/std::nullopt);
}

for (int64_t i = 0; i < timesteps.size(0); ++i) {
auto t = timesteps[i];
current_timestep_ = t;
Expand All @@ -488,6 +536,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
timestep_expanded / 1000.0,
main_shape,
txt_seq_lens,
image_rotary_emb_pos,
/*use_cfg=*/false,
/*step_index=*/i);
noise_pred = noise_pred.slice(1, 0, final_latents.size(1));
Expand All @@ -502,6 +551,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
timestep_expanded / 1000.0,
main_shape,
negative_txt_seq_lens,
image_rotary_emb_neg,
/*use_cfg=*/true,
/*step_index=*/i);

Expand All @@ -525,6 +575,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
timestep_expanded / 1000.0,
main_shape,
txt_seq_lens,
image_rotary_emb_pos,
/*use_cfg=*/false,
/*step_index=*/i);
noise_pred = noise_pred.slice(1, 0, final_latents.size(1));
Expand All @@ -535,6 +586,7 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
timestep_expanded / 1000.0,
main_shape,
negative_txt_seq_lens,
image_rotary_emb_neg,
/*use_cfg=*/true,
/*step_index=*/i);

Expand Down Expand Up @@ -610,6 +662,9 @@ class QwenImageEditPlusPipelineImpl : public QwenImagePipelineBaseImpl {
torch::Tensor current_timestep_;
string prompt_template_encode_;
const ModelArgs& vae_model_args_;
bool use_layer3d_rope_;
QwenEmbedRope pos_embed_{nullptr};
QwenEmbedLayer3DRope pos_embed_3d_rope_{nullptr};
};

REGISTER_MODEL_ARGS(Qwen2Tokenizer, [&] {});
Expand Down
Loading
Loading