Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions diffsynth/diffusion/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,50 @@


def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
from xfuser.core.distributed import (
get_sequence_parallel_rank,
get_sp_group,
)

max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))

timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)

noise = torch.randn_like(inputs["input_latents"])
input_shape = inputs["input_latents"].shape
input_dtype = inputs["input_latents"].dtype
input_device = inputs["input_latents"].device

# Random noise and timestep IDs are generated by SP local rank 0,
# and broadcast to all other ranks in the SP group.
# Alternative implementation:
# use consistent random seeds for noise and timestep_id generation within each sp group’s ranks.
if pipe.sp_size > 1:
sp_group=get_sp_group()
if get_sequence_parallel_rank() == 0:
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
else:
timestep = torch.zeros(1, dtype=pipe.torch_dtype, device=pipe.device)
sp_group.broadcast(timestep, src=0)

if get_sequence_parallel_rank() == 0:
noise = torch.randn(input_shape, dtype=input_dtype, device=input_device)
else:
noise = torch.zeros(input_shape, dtype=input_dtype, device=input_device)
sp_group.broadcast(noise, src=0)
else:
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variables max_timestep_boundary and min_timestep_boundary are already defined on lines 11-12. This recalculation is redundant and can be removed to improve code clarity and reduce duplication.

timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
noise = torch.randn(input_shape, dtype=input_dtype, device=input_device)

inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)

models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)

loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
return loss
Expand Down
1 change: 1 addition & 0 deletions diffsynth/diffusion/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def add_training_config(parser: argparse.ArgumentParser):
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay.")
parser.add_argument("--task", type=str, default="sft", required=False, help="Task type.")
parser.add_argument("--sp_size", type=int, default=1, help="Sequence size. sp size > 1 will init usp for sequence parallal.")
return parser

def add_output_config(parser: argparse.ArgumentParser):
Expand Down
122 changes: 116 additions & 6 deletions diffsynth/diffusion/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,65 @@
from accelerate import Accelerator
from .training_module import DiffusionTrainingModule
from .logger import ModelLogger
import time

def inspect_batch_info(accelerator: Accelerator, batch, step):
    """Print which prompt the current process is about to train on.

    Debugging aid for checking how samples are distributed across processes.

    Args:
        accelerator: Object exposing ``process_index``; used to tag the line
            with the id of the emitting process.
        batch: Sample yielded by the dataloader. Normally a mapping with a
            ``"prompt"`` entry. Any other value (``None``, a list, a mapping
            without that key) is reported with a ``None`` prompt.
        step: Training-step counter included in the message.
    """
    # This helper may be called before the caller has checked the batch for
    # None, so a purely informational print must never abort the training
    # loop. Fall back to None whenever the prompt cannot be looked up.
    try:
        prompt = batch["prompt"]
    except (KeyError, IndexError, TypeError):
        prompt = None
    print(f"[train] id{accelerator.process_index}, step{step}, prompt: {prompt}")

def build_dataloader(
accelerator: Accelerator,
dataset: torch.utils.data.Dataset,
num_workers: int = 1,
sp_size: int = 1,
seed: int = 0,
):
if sp_size > 1:
# When using sequence parallel, it is necessary to ensure that when the sampler uses iter to
# fetch data from the dataloader, each rank within the same SP group obtains the same sample.
if accelerator is not None:
world_size = accelerator.num_processes
rank = accelerator.process_index
else:
raise ValueError(f"Accelerator is None.")

dp_size = world_size // sp_size
if dp_size * sp_size != world_size:
raise ValueError(
f"world_size={world_size}, sp_size={sp_size}, world_size should be diviaible by sp_size"
)

dp_rank = rank // sp_size
sp_rank = rank % sp_size
print(f"accelerator.processid={rank}, accelerator.num_processes={world_size}, "
f"sp_size={sp_size}, dp_size={dp_size}, dp_rank={dp_rank}")
else:
if accelerator is not None:
dp_size = accelerator.num_processes
dp_rank = accelerator.process_index
else:
raise ValueError(f"Accelerator is None.")
print(f"dp_size={dp_size}, dp_rank={dp_rank}")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These print statements appear to be for debugging purposes. It's recommended to remove them or replace them with a proper logging framework (e.g., logging.debug(...)) to control verbosity and keep the standard output clean, especially in non-debug runs.


sampler = torch.utils.data.DistributedSampler(dataset=dataset, num_replicas=dp_size, rank=dp_rank)

def worker_seed_init(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

dataloader_kwargs = dict(
dataset=dataset,
sampler=sampler,
num_workers=num_workers,
pin_memory=True,
worker_init_fn=worker_seed_init,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There are a few critical issues with the worker_seed_init function and its usage that will lead to runtime errors and incorrect behavior:

  1. Missing Imports: The random and np (numpy) modules are used but not imported in this file, which will cause a NameError.
  2. Incorrect Signature: The worker_init_fn for a DataLoader receives the worker_id as an argument. The function signature should be def worker_init_fn(worker_id):.
  3. Unused Seed: The seed parameter passed to build_dataloader is not being used to seed the workers. Each worker should be seeded differently based on a base seed and its ID to ensure reproducibility.

Here's a suggested fix that addresses these points. Please also remember to add import random and import numpy as np at the top of the file.

Suggested change
def worker_seed_init(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
dataloader_kwargs = dict(
dataset=dataset,
sampler=sampler,
num_workers=num_workers,
pin_memory=True,
worker_init_fn=worker_seed_init,
def worker_init_fn(worker_id):
worker_seed = seed + worker_id
random.seed(worker_seed)
np.random.seed(worker_seed)
torch.manual_seed(worker_seed)
dataloader_kwargs = dict(
dataset=dataset,
sampler=sampler,
num_workers=num_workers,
pin_memory=True,
worker_init_fn=worker_init_fn,

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems unnecessary to specify worker_init_fn during dataloader initialization. It was only used to fix randomness when aligning loss precision, so I’ve removed it for now.

collate_fn=lambda x: x[0],
)
dataloader = torch.utils.data.DataLoader(**dataloader_kwargs)

return dataloader

def launch_training_task(
accelerator: Accelerator,
Expand All @@ -15,6 +73,7 @@ def launch_training_task(
num_workers: int = 1,
save_steps: int = None,
num_epochs: int = 1,
sp_size: int = 1,
args = None,
):
if args is not None:
Expand All @@ -23,29 +82,80 @@ def launch_training_task(
num_workers = args.dataset_num_workers
save_steps = args.save_steps
num_epochs = args.num_epochs

sp_size = args.sp_size

train_step = 0

optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)

model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
dataloader = build_dataloader(accelerator, dataset, num_workers, sp_size)
model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)

for epoch_id in range(num_epochs):
for data in tqdm(dataloader):
progress = tqdm(
dataloader,
disable=not accelerator.is_main_process,
desc=f"Epoch {epoch_id + 1}/{num_epochs}",
)

for data in progress:
inspect_batch_info(accelerator, data, train_step)
Comment thread
mahaocong90 marked this conversation as resolved.
Outdated

iter_start = time.time()
timing = {}
if data is None:
continue

with accelerator.accumulate(model):
optimizer.zero_grad()

forward_start = time.time()
if dataset.load_from_cache:
loss = model({}, inputs=data)
else:
loss = model(data)
torch.cuda.synchronize()
timing["forward"] = time.time() - forward_start

backward_start = time.time()
accelerator.backward(loss)
torch.cuda.synchronize()
timing["backward"] = time.time() - backward_start

optim_start = time.time()
optimizer.step()
torch.cuda.synchronize()
timing["optimizer"] = time.time() - optim_start

model_logger.on_step_end(accelerator, model, save_steps)
scheduler.step()

torch.cuda.synchronize()
iter_end = time.time()
timing["step"] = iter_end - iter_start
train_step += 1

if accelerator.is_main_process:
def format_time(key: str) -> str:
value = timing.get(key, 0.0)
return f"{value:.3f}s"

postfix_dict = {
"loss": f"{loss.item():.5f}",
"lr": f"{optimizer.param_groups[0]['lr']:.5e}",
"step/t": format_time("step"),
"fwd/t": format_time("forward"),
"bwd/t": format_time("backward"),
"opt/t": format_time("optimizer"),
}
progress.set_postfix(postfix_dict)
log_msg = f"[Step {train_step:6d}] | " + " | ".join(f"{k}: {v}" for k, v in postfix_dict.items())
progress.write(log_msg)

if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
model_logger.on_training_end(accelerator, model, save_steps)

model_logger.on_training_end(accelerator, model, save_steps)

def launch_data_process_task(
accelerator: Accelerator,
Expand Down
27 changes: 21 additions & 6 deletions diffsynth/pipelines/wan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

class WanVideoPipeline(BasePipeline):

def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16, sp_size=1):
super().__init__(
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
Expand Down Expand Up @@ -80,6 +80,7 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
WanVideoPostUnit_S2V(),
]
self.model_fn = model_fn_wan_video
self.sp_size = sp_size


def enable_usp(self):
Expand All @@ -92,7 +93,6 @@ def enable_usp(self):
for block in self.dit2.blocks:
block.self_attn.forward = types.MethodType(usp_attn_forward, block.self_attn)
self.dit2.forward = types.MethodType(usp_dit_forward, self.dit2)
self.sp_size = get_sequence_parallel_world_size()
self.use_unified_sequence_parallel = True


Expand All @@ -105,6 +105,7 @@ def from_pretrained(
audio_processor_config: ModelConfig = None,
redirect_common_files: bool = True,
use_usp: bool = False,
sp_size: int = 1,
vram_limit: float = None,
):
# Redirect model path
Expand All @@ -122,16 +123,17 @@ def from_pretrained(
print(f"To avoid repeatedly downloading model files, ({model_config.model_id}, {model_config.origin_file_pattern}) is redirected to {redirect_dict[model_config.origin_file_pattern]}. You can use `redirect_common_files=False` to disable file redirection.")
model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]

if use_usp:
from ..utils.xfuser import initialize_usp
initialize_usp(device)
initialize_usp(device, sp_size)
import torch.distributed as dist
from ..core.device.npu_compatible_device import get_device_name
if dist.is_available() and dist.is_initialized():
device = get_device_name()

# Initialize pipeline
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype, sp_size=sp_size)
model_pool = pipe.download_and_load_models(model_configs, vram_limit)

# Fetch models
Expand Down Expand Up @@ -1379,7 +1381,20 @@ def custom_forward(*inputs):
x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
if tea_cache is not None:
tea_cache.store(x)


'''
The all_gather interface in xDiT uses torch's all_gather_into_tensor. As of the torch 2.9 release, all_gather_into_tensor still does not provide a backward method, so it cannot take part in autograd. The related pull request in PyTorch (https://github.com/pytorch/pytorch/pull/168140) has not yet been merged. Therefore, all_gather_into_tensor in xDiT is simply replaced with torch.distributed.nn.functional.all_gather, as shown below, to enable autograd support.

def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceOp.SUM) -> torch.Tensor:
...
# All-gather.
# torch.distributed.all_gather_into_tensor(
# output_tensor, input_, group=self.device_group
# )
gathered_list = torch.distributed.nn.functional.all_gather(input_, group=self.device_group)
output_tensor = torch.cat(gathered_list, dim=0)
...
'''
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
Expand Down
40 changes: 32 additions & 8 deletions diffsynth/utils/xfuser/xdit_context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,35 @@
from ...core.device import parse_nccl_backend, parse_device_type


def initialize_usp(device_type):
def initialize_usp(device_type, sp_size):
import torch.distributed as dist
from xfuser.core.distributed import initialize_model_parallel, init_distributed_environment
dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")
from xfuser.core.distributed import (
initialize_model_parallel,
init_distributed_environment,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_data_parallel_world_size,
get_data_parallel_rank,
)

if not dist.is_initialized():
dist.init_process_group(backend=parse_nccl_backend(device_type), init_method="env://")

init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())

sp_degree = sp_size
dp_degree = int(dist.get_world_size() / sp_degree)
initialize_model_parallel(
sequence_parallel_degree=dist.get_world_size(),
data_parallel_degree=dp_degree,
sequence_parallel_degree=sp_degree,
ring_degree=1,
ulysses_degree=dist.get_world_size(),
ulysses_degree=sp_degree,
)
getattr(torch, device_type).set_device(dist.get_rank())

print(f"[init usp] rank: {dist.get_rank()}, world_size: {dist.get_world_size()}, "
f"sp world size: {get_sequence_parallel_world_size()}, "
f"sp rank: {get_sequence_parallel_rank()}, "
f"dp world size: {get_data_parallel_world_size()}, "
f"dp rank: {get_data_parallel_rank()}")
Comment thread
mahaocong90 marked this conversation as resolved.
Outdated

def sinusoidal_embedding_1d(dim, position):
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
Expand Down Expand Up @@ -133,6 +150,13 @@ def usp_attn_forward(self, x, freqs):
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)

'''
Refer to commit https://github.com/xdit-project/xDiT/pull/598 for the xfuser backward error.
xFuserRingFlashAttnFunc has 17 inputs (including ctx), but it inherits the backward() method from RingFlashAttnFunc which only returns 16 values (3 gradients + 13 Nones)!

The Math
Parent class (RingFlashAttnFunc): 14 forward inputs → backward returns 3 gradients + 11 Nones = 14 returns.
xFuser class (xFuserRingFlashAttnFunc): 17 forward inputs → backward should return 3 gradients + 14 Nones = 17 returns.
Actual: backward only returns 14 values (inherited from the parent without an override).
Error: PyTorch expects 17 gradients but gets only 14 → "expected 17, got 13" (13 = 14 - 1 for ctx).
'''
x = xFuserLongContextAttention()(
None,
query=q,
Expand All @@ -143,4 +167,4 @@ def usp_attn_forward(self, x, freqs):

del q, k, v
getattr(torch, parse_device_type(x.device)).empty_cache()
return self.o(x)
return self.o(x)
Loading