60 changes: 57 additions & 3 deletions diffsynth/diffusion/runner.py
@@ -3,7 +3,13 @@
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger

import time
from ..utils.profiling.flops_profiler import (
    print_model_profile,
    get_flops,
    profile_entire_model,
    unprofile_entire_model,
)

def launch_training_task(
    accelerator: Accelerator,
@@ -29,21 +35,69 @@ def launch_training_task(
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
    model.to(device=accelerator.device)
    model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)


    train_step = 0
    profile_entire_model(model)

    for epoch_id in range(num_epochs):
        for data in tqdm(dataloader):
        progress = tqdm(
            dataloader,
            disable=not accelerator.is_main_process,
            desc=f"Epoch {epoch_id + 1}/{num_epochs}",
        )

        for data in progress:
            iter_start = time.time()
            timing = {}
            if data is None:
                continue

            with accelerator.accumulate(model):
                optimizer.zero_grad()

                if dataset.load_from_cache:
                    loss = model({}, inputs=data)
                else:
                    loss = model(data)

                t5_Tflops, wan_Tflops, vae_Tflops = get_flops(model)
                accelerator.backward(loss)
                optimizer.step()

                model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
                scheduler.step()

            torch.cuda.synchronize()
            iter_end = time.time()
            timing["step"] = iter_end - iter_start
            train_step += 1

            total_flops = t5_Tflops + wan_Tflops + vae_Tflops
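            # heuristic, not measured: backward is taken as ~2x the forward FLOPs,
            # so one full optimizer step is ~3x the forward count accumulated above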
            TFLOPS = total_flops * 3 / timing["step"]

            if accelerator.is_main_process:
                def format_time(key: str) -> str:
                    value = timing.get(key, 0.0)
                    return f"{value:.3f}s"
Contributor (medium)

This function is redefined on every iteration of the training loop, which is inefficient. It's better to define it once outside the loop. To do so, you'll need to pass the timing dictionary as an argument.

For example, you could define it before the loop:

def format_time(timing_dict: dict, key: str) -> str:
    value = timing_dict.get(key, 0.0)
    return f"{value:.3f}s"

And then call it inside the loop as format_time(timing, "step").

Since I cannot suggest changes outside of the current diff hunk, I'm leaving this as a comment for you to refactor.
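
A self-contained check of the refactored helper (the timing values here are made up):

def format_time(timing_dict: dict, key: str) -> str:
    value = timing_dict.get(key, 0.0)
    return f"{value:.3f}s"

timing = {"step": 1.234}
print(format_time(timing, "step"))  # -> 1.234s
print(format_time(timing, "data"))  # -> 0.000s (missing keys fall back to 0.0)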


                postfix_dict = {
                    "Rank": f"{accelerator.process_index}",
                    "loss": f"{loss.item():.5f}",
                    "lr": f"{optimizer.param_groups[0]['lr']:.5e}",
                    "step/t": format_time("step"),
                    "[t5] Tflops": f"{t5_Tflops:.3f}",
                    "[dit] Tflops": f"{wan_Tflops:.3f}",
                    "[vae] Tflops": f"{vae_Tflops:.3f}",
                    "TFLOPS": f"{TFLOPS:.3f}",
                }
                progress.set_postfix(postfix_dict)
                log_msg = f"[Step {train_step:6d}] | " + " | ".join(f"{k}: {v}" for k, v in postfix_dict.items())
                progress.write(log_msg)

        if save_steps is None:
            model_logger.on_epoch_end(accelerator, model, epoch_id)

    unprofile_entire_model(model)
    model_logger.on_training_end(accelerator, model, save_steps)


7 changes: 6 additions & 1 deletion diffsynth/models/wan_video_text_encoder.py
@@ -70,6 +70,11 @@ def forward(self, x, context=None, mask=None, pos_bias=None):
        k = self.k(context).view(b, -1, n, c)
        v = self.v(context).view(b, -1, n, c)
        # For caculate flops
Contributor (medium)

There is a typo in the comment. "caculate" should be "calculate".

Suggested change
        # For caculate flops
        # For calculate flops

        self.q_shape = q.shape
        self.k_shape = k.shape
        self.v_shape = v.shape

        # attention bias
        attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
        if pos_bias is not None:
@@ -327,4 +332,4 @@ def _clean(self, text):
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        return text
        return text
6 changes: 6 additions & 0 deletions diffsynth/utils/profiling/__init__.py
@@ -0,0 +1,6 @@
from .flops_profiler import (
    profile_entire_model,
    unprofile_entire_model,
    get_flops,
    print_model_profile,
)
252 changes: 252 additions & 0 deletions diffsynth/utils/profiling/flops_profiler.py
@@ -0,0 +1,252 @@
import torch
import torch.nn as nn
from functools import wraps
import time
from collections import defaultdict
import flash_attn
from einops import rearrange
Contributor (medium)

These imports (defaultdict, flash_attn, einops) do not appear to be used in this file and can be removed to improve code clarity.
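
Assuming only those three flagged imports are dropped, the import block at the top of this file would presumably reduce to:

import torch
import torch.nn as nn
from functools import wraps
import time
from torch.utils.flop_counter import conv_flop_count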

from torch.utils.flop_counter import conv_flop_count

def get_dit_flops(model):
    def get_dit_flops(dit_block_model):
        total_flops = 0
        for sub_model in dit_block_model.modules():
            total_flops += getattr(sub_model, '__flops__', 0)
        return total_flops

    total_flops = 0
    total_duration = 0
    for sub_module in model.modules():
        if sub_module.__class__.__name__ == 'DiTBlock':
            total_flops += get_dit_flops(sub_module)
            total_duration += getattr(sub_module, '__duration__', 0)

    Tflops = total_flops / 1e12
    return Tflops
Contributor (medium)

This function has a couple of issues affecting readability and maintainability:

  1. The nested function at line 11 has the same name as the outer function, which can be confusing. It's better to give it a more descriptive name, like _get_dit_flops_recursive.
  2. The total_duration variable is calculated but never used. It should be removed.
Suggested change
def get_dit_flops(model):
    def get_dit_flops(dit_block_model):
        total_flops = 0
        for sub_model in dit_block_model.modules():
            total_flops += getattr(sub_model, '__flops__', 0)
        return total_flops

    total_flops = 0
    total_duration = 0
    for sub_module in model.modules():
        if sub_module.__class__.__name__ == 'DiTBlock':
            total_flops += get_dit_flops(sub_module)
            total_duration += getattr(sub_module, '__duration__', 0)

    Tflops = total_flops / 1e12
    return Tflops
def get_dit_flops(model):
    def _get_dit_flops_recursive(dit_block_model):
        total_flops = 0
        for sub_model in dit_block_model.modules():
            total_flops += getattr(sub_model, '__flops__', 0)
        return total_flops

    total_flops = 0
    for sub_module in model.modules():
        if sub_module.__class__.__name__ == 'DiTBlock':
            total_flops += _get_dit_flops_recursive(sub_module)

    Tflops = total_flops / 1e12
    return Tflops


def get_flops(model):
    def get_module_flops(module):
        if not hasattr(module, "__flops__"):
            module.__flops__ = 0

        flops = module.__flops__
        # iterate over immediate children modules
        for child in module.children():
            flops += get_module_flops(child)
        return flops

    t5_flops = 0
    wan_flops = 0
    vae_flops = 0
    for module in model.modules():
        if module.__class__.__name__ == 'WanTextEncoder':
            t5_flops = get_module_flops(module)
        if module.__class__.__name__ == 'WanModel':
            wan_flops = get_module_flops(module)
        if module.__class__.__name__ == 'WanVideoVAE38':
            vae_flops = get_module_flops(module)
    return t5_flops / 1e12, wan_flops / 1e12, vae_flops / 1e12

def print_model_profile(model):
    def get_module_flops(module):
        if not hasattr(module, "__flops__"):
            module.__flops__ = 0

        flops = module.__flops__
        # iterate over immediate children modules
        for child in module.children():
            flops += get_module_flops(child)
        return flops

    def get_module_duration(module):
        if not hasattr(module, "__duration__"):
            module.__duration__ = 0

        duration = module.__duration__
        if duration == 0: # e.g. ModuleList
            for m in module.children():
                duration += get_module_duration(m)
        return duration

    def flops_repr(module):
        flops = get_module_flops(module)
        duration = get_module_duration(module) * 1000
        items = [
            "{:,} flops".format(flops),
            "{:.3f} ms".format(duration),
        ]
        original_extra_repr = module.original_extra_repr()
        if original_extra_repr:
            items.append(original_extra_repr)
        return ", ".join(items)

    def add_extra_repr(module):
        flops_extra_repr = flops_repr.__get__(module)
        if module.extra_repr != flops_extra_repr:
            module.original_extra_repr = module.extra_repr
            module.extra_repr = flops_extra_repr
            assert module.extra_repr != module.original_extra_repr

    def del_extra_repr(module):
        if hasattr(module, "original_extra_repr"):
            module.extra_repr = module.original_extra_repr
            del module.original_extra_repr

    model.apply(add_extra_repr)
    print(model)
    model.apply(del_extra_repr)

def get_module_flops(module, *args, result=None, **kwargs):
    module_type = module.__class__.__name__
    module_original_fwd = module._original_forward.__name__

    if module_type == 'RMSNorm':
        x = args[0]
        return x.numel() * 4

    elif module_type == 'RMS_norm':
        x = args[0]
        return x.numel() * 4

    elif module_type == 'Dropout':
        x = args[0]
        return x.numel() * 2

    elif module_type == 'LayerNorm':
        x = args[0]
        has_affine = module.weight is not None
        return x.numel() * (5 if has_affine else 4)

    elif module_type == 'Linear':
        x = args[0]
        return x.numel() * module.weight.size(0) * 2

    elif module_type == 'ReLU':
        x = args[0]
        return x.numel()

    elif module_type == 'GELU':
        x = args[0]
        return x.numel()

    elif module_type == 'SiLU':
        x = args[0]
        return x.numel()

    elif module_type == 'Conv3d' or module_type == 'CausalConv3d' or module_type == 'Conv2d':
        x_shape = args[0].shape
        weight = getattr(module, 'weight', None)
        w_shape = weight.shape
        out_shape = result.shape

        flops = conv_flop_count(
            x_shape=x_shape,
            w_shape=w_shape,
            out_shape=out_shape,
            transposed=False
        )
        return flops

    # AttentionModule input is 3D shape, USP input is 4D shape.
    #
    # 3D shape:
    # q [batch, target_seq_len, Dim]
    # k [batch, source_seq_len, Dim]
    # v [batch, source_seq_len, Dim]
    # flops = (batch * target_seq_len * source_seq_len) * Dim * 2
    #       + (batch * target_seq_len * Dim) * source_seq_len * 2
    #       = 4 * (batch * target_seq_len * source_seq_len * Dim)
    #
    # 4D shape:
    # q [batch, target_seq_len, head, dim]
    # k [batch, source_seq_len, head, dim]
    # v [batch, source_seq_len, head, dim]
    # flops = 4 * (batch * target_seq_len * source_seq_len * head * dim)
    #
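    # Worked example (hypothetical shapes, not from this PR): 3D inputs with
    # q = k = v of shape [2, 1024, 2048] give
    # flops = 4 * 2 * 1024 * 1024 * 2048 ≈ 1.72e10 per attention call.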
    elif module_type == 'AttentionModule':
        q = args[0]
        k = args[1]
        v = args[2]

        b, ts, dq = q.shape
        _, ss, _ = k.shape
        _, _, dv = v.shape
        flops = (b * ts * ss * dq) * 2 + (b * ts * ss * dv) * 2
        return flops

    elif module_original_fwd == 'usp_attn_forward' or module_type == 'T5Attention':
        q_shape = module.q_shape
        k_shape = module.k_shape
        v_shape = module.v_shape

        b, ts, n, dq = q_shape
        _, ss, _, _ = k_shape
        _, _, _, dv = v_shape
        flops = (b * ts * ss * n * dq) * 2 + (b * ts * ss * n * dv) * 2
        return flops

    elif module_type == 'GateModule':
        x = args[0]
        return x.numel() * 2

    elif module_type == 'T5LayerNorm':
        x = args[0]
        return x.numel() * 4

    elif module_type == 'T5RelativeEmbedding':
        lq = args[0]
        lk = args[1]
        return lq * lk * 10

    else:
        return 0

def flops_counter(flops_func=None):
    def decorator(forward_func):
        @wraps(forward_func)
        def wrapper(self, *args, **kwargs):
            start_time = time.perf_counter()

            result = forward_func(self, *args, **kwargs)

            self.__flops__ = get_module_flops(self, *args, result=result, **kwargs)

            end_time = time.perf_counter()
            self.__duration__ = (end_time - start_time)

            return result
        return wrapper
    return decorator


def wrap_existing_module(module, verbose_profiling=False):
    # save original fwd
    module.verbose_profiling = verbose_profiling
    module._original_forward = module.forward

    @flops_counter()
    def profiled_forward(self, x, *args, **kwargs):
        return module._original_forward(x, *args, **kwargs)
Contributor (high)

The signature of profiled_forward is (self, x, *args, **kwargs), which assumes that every wrapped module's forward method has a first positional argument x. This is not always true and can lead to a TypeError for modules with different forward signatures (e.g., no arguments, or keyword-only arguments). The signature should be (self, *args, **kwargs) to be generic and robust.

Suggested change
    def profiled_forward(self, x, *args, **kwargs):
        return module._original_forward(x, *args, **kwargs)
    def profiled_forward(self, *args, **kwargs):
        return module._original_forward(*args, **kwargs)
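
A minimal repro of the failure mode the generic signature avoids (hypothetical Sampler module, not from this PR):

import torch
import torch.nn as nn

class Sampler(nn.Module):
    # forward takes a keyword-only argument; there is no positional x to bind
    def forward(self, *, num_samples: int = 1):
        return torch.randn(num_samples, 4)

m = Sampler()
m._original_forward = m.forward

def profiled_forward(self, *args, **kwargs):  # generic signature
    return m._original_forward(*args, **kwargs)

m.forward = profiled_forward.__get__(m, type(m))
print(m(num_samples=2).shape)  # torch.Size([2, 4]); a (self, x, ...) wrapper raises TypeError here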


    module.forward = profiled_forward.__get__(module, type(module))
    return module

def profile_entire_model(model, verbose_profiling=True):
    for name, module in model.named_modules():
        wrap_existing_module(module, verbose_profiling)
    return model

def unwrap_existing_module(module):
    if hasattr(module, "_original_forward"):
        module.forward = module._original_forward
        del module._original_forward

    if hasattr(module, "verbose_profiling"):
        del module.verbose_profiling
    return module

def unprofile_entire_model(model):
    for name, module in model.named_modules():
        unwrap_existing_module(module)
    return model

7 changes: 6 additions & 1 deletion diffsynth/utils/xfuser/xdit_context_parallel.py
@@ -128,6 +128,11 @@ def usp_attn_forward(self, x, freqs):
    k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
    v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

    # For caculate flops
    self.q_shape = q.shape
    self.k_shape = k.shape
    self.v_shape = v.shape

    attn_type = AttnType.FA
    ring_impl_type = "basic"
    if IS_NPU_AVAILABLE:
@@ -143,4 +148,4 @@ def usp_attn_forward(self, x, freqs):

    del q, k, v
    getattr(torch, parse_device_type(x.device)).empty_cache()
    return self.o(x)
    return self.o(x)