29 changes: 27 additions & 2 deletions diffsynth/pipelines/wan_video.py
@@ -2,7 +2,7 @@
import numpy as np
from PIL import Image
from einops import repeat
from typing import Optional, Union
from typing import Callable, Optional, Union
from einops import rearrange
import numpy as np
from PIL import Image
@@ -247,9 +247,22 @@ def __call__(
# progress_bar
progress_bar_cmd=tqdm,
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
# Prior-based step skip: optional callback after each denoising step
step_callback: Optional[Callable[[int, torch.Tensor, torch.Tensor], None]] = None,
# Prior-based step skip: resume from saved latent (requires prior_latents + prior_timesteps)
prior_latents: Optional[torch.Tensor] = None,
prior_timesteps: Optional[torch.Tensor] = None,
prior_sigmas: Optional[torch.Tensor] = None,
start_from_step: Optional[int] = None,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)

# Prior-based step skip: override latents, timesteps, and sigmas when resuming from prior
if prior_latents is not None and prior_timesteps is not None and start_from_step is not None:
self.scheduler.timesteps = prior_timesteps.to(self.scheduler.timesteps.device)
if prior_sigmas is not None:
self.scheduler.sigmas = prior_sigmas.to(self.scheduler.sigmas.device)

# Inputs
inputs_posi = {
@@ -284,10 +297,18 @@ def __call__(
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)

# Prior-based step skip: replace latents with loaded prior
if prior_latents is not None and start_from_step is not None:
inputs_shared["latents"] = prior_latents.to(dtype=self.torch_dtype, device=self.device)
Comment on lines +301 to +302
Contributor

critical

The condition for replacing latents with the prior is missing a check for prior_timesteps. The check at line 262 correctly requires prior_latents, prior_timesteps, and start_from_step. If prior_timesteps is not provided here, the scheduler will use incorrect timesteps with the loaded prior latents, which can lead to incorrect generation results. To ensure consistency and prevent bugs, the condition should be the same as the one at line 262.

Suggested change
if prior_latents is not None and start_from_step is not None:
inputs_shared["latents"] = prior_latents.to(dtype=self.torch_dtype, device=self.device)
if prior_latents is not None and prior_timesteps is not None and start_from_step is not None:
inputs_shared["latents"] = prior_latents.to(dtype=self.torch_dtype, device=self.device)


# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timesteps = self.scheduler.timesteps
start_idx = (start_from_step + 1) if start_from_step is not None else 0
for progress_id, timestep in enumerate(progress_bar_cmd(timesteps)):
if progress_id < start_idx:
continue
# Switch DiT if necessary
if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2:
self.load_models_to_device(self.in_iteration_models_2)
@@ -312,6 +333,10 @@ def __call__(
inputs_shared["latents"] = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], inputs_shared["latents"])
if "first_frame_latents" in inputs_shared:
inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]

# Prior-based step skip: call optional callback after each step
if step_callback is not None:
step_callback(progress_id, inputs_shared["latents"].clone(), timestep)

# VACE (TODO: remove it)
if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None):
6 changes: 6 additions & 0 deletions docs/en/Model_Details/Wan.md
@@ -201,6 +201,12 @@ Input parameters for `WanVideoPipeline` inference include:

If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.

### Prior-Based Step Skip

For fixed identity/scene with varying motion (e.g. lip-sync, different actions), early diffusion steps are largely redundant. You can run full inference once, save latents at each step, then resume from a saved latent to run only the remaining steps — ~70% fewer steps, same quality.

See [prior-based step skip example](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/prior_based_step_skip) for `generate_prior.py` and `infer_from_prior.py`.
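The skip itself is simple index arithmetic: resuming from a prior saved at step `k` executes only steps `k + 1` onward. A minimal pure-Python sketch of that logic (the real loop in `WanVideoPipeline.__call__` iterates `self.scheduler.timesteps`; the function name here is illustrative, not part of the API):

```python
from typing import List, Optional

def steps_to_run(num_inference_steps: int, start_from_step: Optional[int] = None) -> List[int]:
    """Indices of the denoising steps actually executed.

    Mirrors the loop guard in the pipeline: steps 0..start_from_step are
    skipped because their result is already captured in the saved latent.
    """
    start_idx = (start_from_step + 1) if start_from_step is not None else 0
    return [i for i in range(num_inference_steps) if i >= start_idx]
```

With 10 total steps and a prior saved at step 6, only steps 7, 8, and 9 run.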

## Model Training

Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include:
6 changes: 6 additions & 0 deletions docs/zh/Model_Details/Wan.md
@@ -202,6 +202,12 @@ DeepSpeed ZeRO 3 training: Wan series models support DeepSpeed ZeRO 3 training, which

If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low-VRAM configurations for each model in the example code; see the table in the "Model Overview" section above.

### Prior-Based Step Skip

When the identity/scene is fixed and only the motion varies (e.g. lip-sync, different actions), the early diffusion steps are largely redundant. You can run full inference once and save the latents at each step, then resume from a saved latent and run only the remaining steps, with roughly 70% fewer steps at comparable quality.

See the [prior-based step skip example](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/prior_based_step_skip) for `generate_prior.py` and `infer_from_prior.py`.

## Model Training

Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include:
128 changes: 128 additions & 0 deletions examples/wanvideo/prior_based_step_skip/README.md
@@ -0,0 +1,128 @@
# Prior-Based Diffusion Step Skip

**~70% fewer inference steps, same quality, zero retraining.**

When you have a **fixed identity or scene** and only **one aspect varies** (e.g. motion, lip-sync, lighting), early diffusion steps are largely redundant. This module lets you:

1. **Generate a prior** — Run full inference once, save latents at each step
2. **Infer from prior** — Load a saved latent (e.g. step 6) and run only the remaining 3–4 steps
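The capture side of this workflow uses the `step_callback(progress_id, latents, timestep)` hook that this change adds to `WanVideoPipeline.__call__`. A hedged sketch of such a callback; the save function is injected so the sketch carries no framework dependency (in practice it would be `torch.save`, and `make_saving_callback` is a hypothetical name, not part of the scripts):

```python
import os

def make_saving_callback(output_dir, save_fn):
    """Build a step_callback that writes each step's latents to disk.

    save_fn(obj, path) is assumed to be torch.save-compatible; it is
    injected here so the sketch has no framework dependency.
    """
    os.makedirs(output_dir, exist_ok=True)

    def step_callback(progress_id, latents, timestep):
        # File naming mirrors the step_0000.pt layout produced by generate_prior.py.
        save_fn(latents, os.path.join(output_dir, f"step_{progress_id:04d}.pt"))

    return step_callback
```

The returned callback would then be passed to the pipeline as `step_callback=...`.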

## Quick Start

Scripts work from **repo root** or from this directory. Run from repo root for consistent paths.

### Step 1: Generate the prior

**From repo root:**

```bash
# Download example image and run full inference
python examples/wanvideo/prior_based_step_skip/generate_prior.py \
--download_example \
--output_dir ./prior_output \
--num_inference_steps 10
```

**Or with your own image:**

```bash
python examples/wanvideo/prior_based_step_skip/generate_prior.py \
--image path/to/image.jpg \
--output_dir ./prior_output \
--num_inference_steps 10
```

**From this directory:**

```bash
cd examples/wanvideo/prior_based_step_skip

# With --download_example (downloads to repo root data/)
python generate_prior.py --download_example --output_dir ./prior_output --num_inference_steps 10

# Or with your own image
python generate_prior.py --image path/to/image.jpg --output_dir ./prior_output --num_inference_steps 10
```

Output: `./prior_output/run_<id>/` with `step_0000.pt` … `step_0009.pt`, `run_metadata.json`, and `output_full.mp4`.
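Given that layout, resuming only needs to locate the right `step_XXXX.pt` file and the run metadata. A sketch under those layout assumptions (the helper name is hypothetical; the real loading logic lives in `prior_utils.py`):

```python
import json
import os

def locate_prior(run_dir: str, start_step: int):
    """Return the latent path for start_step plus the saved run metadata.

    Assumes the step_XXXX.pt / run_metadata.json layout described above;
    deserializing the latent itself (e.g. with torch.load) is left to the caller.
    """
    latent_path = os.path.join(run_dir, f"step_{start_step:04d}.pt")
    if not os.path.exists(latent_path):
        raise FileNotFoundError(f"No saved latent for step {start_step}: {latent_path}")
    with open(os.path.join(run_dir, "run_metadata.json")) as f:
        metadata = json.load(f)
    return latent_path, metadata
```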

### Step 2: Run accelerated inference

```bash
# From repo root (replace run_<id> with actual run ID from step 1)
python examples/wanvideo/prior_based_step_skip/infer_from_prior.py \
--prior_dir ./prior_output/run_<id> \
--start_step 6 \
--image data/examples/wan/input_image.jpg \
--prompt "Different motion: the boat turns sharply to the left."
```

Or from this directory:

```bash
python infer_from_prior.py \
--prior_dir ./prior_output/run_<id> \
--start_step 6 \
--image data/examples/wan/input_image.jpg \
--prompt "Different motion: the boat turns sharply to the left."
```

This runs only 3 steps (7, 8, 9) instead of 10 — ~70% fewer steps.

## How It Works

| Steps | Content |
|---------|-----------------------------------------------|
| 1–5 | Identity formation (geometry, lighting) |
| **6** | **Inflection point** — identity formed, motion not yet committed |
| 7–10 | Temporal refinement (details, sharpness) |

By injecting the latent at step 6, we skip redundant identity formation. The remaining steps refine the motion (or other varying aspect) driven by the new prompt.
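The "~70% fewer steps" figure follows directly from that table: resuming after step `k` out of `n` total steps executes only `n - (k + 1)` steps. A quick sketch of the arithmetic (function name is illustrative only):

```python
def fraction_skipped(num_inference_steps: int, start_step: int) -> float:
    """Fraction of denoising steps skipped when resuming after start_step."""
    executed = num_inference_steps - (start_step + 1)
    return 1.0 - executed / num_inference_steps
```

For the default configuration (10 steps, resume at step 6), 7 of 10 steps are skipped.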

## Scripts

| Script | Purpose |
|---------------------|--------------------------------------------------------|
| `generate_prior.py` | Full inference with latent saving at each step |
| `infer_from_prior.py` | Accelerated inference from a saved prior |
| `prior_utils.py` | Latent save/load, metadata, scheduler validation |

## Options

### generate_prior.py

- `--image` — Input image (required unless `--download_example`)
- `--download_example` — Download example image from ModelScope (saves to `data/examples/wan/`)
- `--output_dir` — Where to save latents (default: `./prior_output`)
- `--num_inference_steps` — Total steps (default: 10)
- `--start_step` — Not used by this script; recorded for reference when later running `infer_from_prior.py`
- `--save_decoded_videos` — Decode and save video at each step (for finding formation point)

### infer_from_prior.py

- `--prior_dir` — Path to prior run (e.g. `./prior_output/run_123`)
- `--start_step` — Step to resume from (default: 6)
- `--image` — Same image used for prior generation
- `--prompt` — New prompt for the varying aspect

## Scheduler Identity

The scheduler used during prior generation **must match** inference. The scripts save and validate:

- `num_inference_steps`
- `denoising_strength`
- `sigma_shift`
- `scheduler_timesteps` and `scheduler_sigmas`

Do not change these between prior generation and inference.
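A hedged sketch of what such a validation could look like (the actual check lives in `prior_utils.py`; the function name here is an assumption, and the real check also covers `scheduler_timesteps` and `scheduler_sigmas`):

```python
def validate_scheduler_metadata(saved: dict, current: dict) -> None:
    """Raise if a scheduler setting differs between prior generation and inference."""
    keys = ("num_inference_steps", "denoising_strength", "sigma_shift")
    mismatches = {
        k: (saved.get(k), current.get(k))
        for k in keys
        if saved.get(k) != current.get(k)
    }
    if mismatches:
        raise ValueError(f"Scheduler settings differ from the prior run: {mismatches}")
```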

## Requirements

- DiffSynth-Studio installed (`pip install -e .` from repo root)
- GPU with ≥8GB VRAM (low-VRAM config uses disk offload)
- Wan2.1-I2V-14B-480P model (downloaded automatically from ModelScope)

## See Also

- [Wan model documentation](../../../docs/en/Model_Details/Wan.md)
- [Model inference examples](../model_inference_low_vram/)