-
Notifications
You must be signed in to change notification settings - Fork 1.2k
[New Feasture]: Support SP + DP parallal on Wan training #1223
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
f26e1b8
7d1aa07
b2cb673
420fd08
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,7 +3,65 @@ | |||||||||||||||||||||||||||||||||||||||||||||||
| from accelerate import Accelerator | ||||||||||||||||||||||||||||||||||||||||||||||||
| from .training_module import DiffusionTrainingModule | ||||||||||||||||||||||||||||||||||||||||||||||||
| from .logger import ModelLogger | ||||||||||||||||||||||||||||||||||||||||||||||||
| import time | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def inspect_batch_info(accelerator: Accelerator, batch, step): | ||||||||||||||||||||||||||||||||||||||||||||||||
| batch_type = type(batch) | ||||||||||||||||||||||||||||||||||||||||||||||||
| batch_len = len(batch) if isinstance(batch, (list, dict)) else batch.size(0) | ||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"[train] id{accelerator.process_index}, step{step}, prompt: {batch['prompt']}") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def build_dataloader( | ||||||||||||||||||||||||||||||||||||||||||||||||
| accelerator: Accelerator, | ||||||||||||||||||||||||||||||||||||||||||||||||
| dataset: torch.utils.data.Dataset, | ||||||||||||||||||||||||||||||||||||||||||||||||
| num_workers: int = 1, | ||||||||||||||||||||||||||||||||||||||||||||||||
| sp_size: int = 1, | ||||||||||||||||||||||||||||||||||||||||||||||||
| seed: int = 0, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||
| if sp_size > 1: | ||||||||||||||||||||||||||||||||||||||||||||||||
| # When using sequence parallel, it is necessary to ensure that when the sampler uses iter to | ||||||||||||||||||||||||||||||||||||||||||||||||
| # fetch data from the dataloader, each rank within the same SP group obtains the same sample. | ||||||||||||||||||||||||||||||||||||||||||||||||
| if accelerator is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| world_size = accelerator.num_processes | ||||||||||||||||||||||||||||||||||||||||||||||||
| rank = accelerator.process_index | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Accelerator is None.") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| dp_size = world_size // sp_size | ||||||||||||||||||||||||||||||||||||||||||||||||
| if dp_size * sp_size != world_size: | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"world_size={world_size}, sp_size={sp_size}, world_size should be diviaible by sp_size" | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| dp_rank = rank // sp_size | ||||||||||||||||||||||||||||||||||||||||||||||||
| sp_rank = rank % sp_size | ||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"accelerator.processid={rank}, accelerator.num_processes={world_size}, " | ||||||||||||||||||||||||||||||||||||||||||||||||
| f"sp_size={sp_size}, dp_size={dp_size}, dp_rank={dp_rank}") | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| if accelerator is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| dp_size = accelerator.num_processes | ||||||||||||||||||||||||||||||||||||||||||||||||
| dp_rank = accelerator.process_index | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError(f"Accelerator is None.") | ||||||||||||||||||||||||||||||||||||||||||||||||
| print(f"dp_size={dp_size}, dp_rank={dp_rank}") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| sampler = torch.utils.data.DistributedSampler(dataset=dataset, num_replicas=dp_size, rank=dp_rank) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def worker_seed_init(seed): | ||||||||||||||||||||||||||||||||||||||||||||||||
| random.seed(seed) | ||||||||||||||||||||||||||||||||||||||||||||||||
| np.random.seed(seed) | ||||||||||||||||||||||||||||||||||||||||||||||||
| torch.manual_seed(seed) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| dataloader_kwargs = dict( | ||||||||||||||||||||||||||||||||||||||||||||||||
| dataset=dataset, | ||||||||||||||||||||||||||||||||||||||||||||||||
| sampler=sampler, | ||||||||||||||||||||||||||||||||||||||||||||||||
| num_workers=num_workers, | ||||||||||||||||||||||||||||||||||||||||||||||||
| pin_memory=True, | ||||||||||||||||||||||||||||||||||||||||||||||||
| worker_init_fn=worker_seed_init, | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are a few critical issues with the
Here's a suggested fix that addresses these points. Please also remember to add
Suggested change
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems unnecessary to specify worker_init_fn during dataloader initialization. It was only used to fix randomness when aligning loss precision, so I’ve removed it for now. |
||||||||||||||||||||||||||||||||||||||||||||||||
| collate_fn=lambda x: x[0], | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dataloader = torch.utils.data.DataLoader(**dataloader_kwargs) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| return dataloader | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def launch_training_task( | ||||||||||||||||||||||||||||||||||||||||||||||||
| accelerator: Accelerator, | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -15,6 +73,7 @@ def launch_training_task( | |||||||||||||||||||||||||||||||||||||||||||||||
| num_workers: int = 1, | ||||||||||||||||||||||||||||||||||||||||||||||||
| save_steps: int = None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| num_epochs: int = 1, | ||||||||||||||||||||||||||||||||||||||||||||||||
| sp_size: int = 1, | ||||||||||||||||||||||||||||||||||||||||||||||||
| args = None, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||
| if args is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -23,29 +82,80 @@ def launch_training_task( | |||||||||||||||||||||||||||||||||||||||||||||||
| num_workers = args.dataset_num_workers | ||||||||||||||||||||||||||||||||||||||||||||||||
| save_steps = args.save_steps | ||||||||||||||||||||||||||||||||||||||||||||||||
| num_epochs = args.num_epochs | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| sp_size = args.sp_size | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| train_step = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| optimizer = torch.optim.AdamW(model.trainable_modules(), lr=learning_rate, weight_decay=weight_decay) | ||||||||||||||||||||||||||||||||||||||||||||||||
| scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) | ||||||||||||||||||||||||||||||||||||||||||||||||
| dataloader = build_dataloader(accelerator, dataset, num_workers, sp_size) | ||||||||||||||||||||||||||||||||||||||||||||||||
| model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| for epoch_id in range(num_epochs): | ||||||||||||||||||||||||||||||||||||||||||||||||
| for data in tqdm(dataloader): | ||||||||||||||||||||||||||||||||||||||||||||||||
| progress = tqdm( | ||||||||||||||||||||||||||||||||||||||||||||||||
| dataloader, | ||||||||||||||||||||||||||||||||||||||||||||||||
| disable=not accelerator.is_main_process, | ||||||||||||||||||||||||||||||||||||||||||||||||
| desc=f"Epoch {epoch_id + 1}/{num_epochs}", | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| for data in progress: | ||||||||||||||||||||||||||||||||||||||||||||||||
| inspect_batch_info(accelerator, data, train_step) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
mahaocong90 marked this conversation as resolved.
Outdated
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| iter_start = time.time() | ||||||||||||||||||||||||||||||||||||||||||||||||
| timing = {} | ||||||||||||||||||||||||||||||||||||||||||||||||
| if data is None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| with accelerator.accumulate(model): | ||||||||||||||||||||||||||||||||||||||||||||||||
| optimizer.zero_grad() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| forward_start = time.time() | ||||||||||||||||||||||||||||||||||||||||||||||||
| if dataset.load_from_cache: | ||||||||||||||||||||||||||||||||||||||||||||||||
| loss = model({}, inputs=data) | ||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||
| loss = model(data) | ||||||||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize() | ||||||||||||||||||||||||||||||||||||||||||||||||
| timing["forward"] = time.time() - forward_start | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| backward_start = time.time() | ||||||||||||||||||||||||||||||||||||||||||||||||
| accelerator.backward(loss) | ||||||||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize() | ||||||||||||||||||||||||||||||||||||||||||||||||
| timing["backward"] = time.time() - backward_start | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| optim_start = time.time() | ||||||||||||||||||||||||||||||||||||||||||||||||
| optimizer.step() | ||||||||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize() | ||||||||||||||||||||||||||||||||||||||||||||||||
| timing["optimizer"] = time.time() - optim_start | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| model_logger.on_step_end(accelerator, model, save_steps) | ||||||||||||||||||||||||||||||||||||||||||||||||
| scheduler.step() | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| torch.cuda.synchronize() | ||||||||||||||||||||||||||||||||||||||||||||||||
| iter_end = time.time() | ||||||||||||||||||||||||||||||||||||||||||||||||
| timing["step"] = iter_end - iter_start | ||||||||||||||||||||||||||||||||||||||||||||||||
| train_step += 1 | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if accelerator.is_main_process: | ||||||||||||||||||||||||||||||||||||||||||||||||
| def format_time(key: str) -> str: | ||||||||||||||||||||||||||||||||||||||||||||||||
| value = timing.get(key, 0.0) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return f"{value:.3f}s" | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| postfix_dict = { | ||||||||||||||||||||||||||||||||||||||||||||||||
| "loss": f"{loss.item():.5f}", | ||||||||||||||||||||||||||||||||||||||||||||||||
| "lr": f"{optimizer.param_groups[0]['lr']:.5e}", | ||||||||||||||||||||||||||||||||||||||||||||||||
| "step/t": format_time("step"), | ||||||||||||||||||||||||||||||||||||||||||||||||
| "fwd/t": format_time("forward"), | ||||||||||||||||||||||||||||||||||||||||||||||||
| "bwd/t": format_time("backward"), | ||||||||||||||||||||||||||||||||||||||||||||||||
| "opt/t": format_time("optimizer"), | ||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||
| progress.set_postfix(postfix_dict) | ||||||||||||||||||||||||||||||||||||||||||||||||
| log_msg = f"[Step {train_step:6d}] | " + " | ".join(f"{k}: {v}" for k, v in postfix_dict.items()) | ||||||||||||||||||||||||||||||||||||||||||||||||
| progress.write(log_msg) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| if save_steps is None: | ||||||||||||||||||||||||||||||||||||||||||||||||
| model_logger.on_epoch_end(accelerator, model, epoch_id) | ||||||||||||||||||||||||||||||||||||||||||||||||
| model_logger.on_training_end(accelerator, model, save_steps) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| model_logger.on_training_end(accelerator, model, save_steps) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def launch_data_process_task( | ||||||||||||||||||||||||||||||||||||||||||||||||
| accelerator: Accelerator, | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variables
max_timestep_boundaryandmin_timestep_boundaryare already defined on lines 11-12. This recalculation is redundant and can be removed to improve code clarity and reduce duplication.