diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/README.md b/sd-legacy-stable-diffusion-v1-5/VitisAI/README.md new file mode 100644 index 00000000..231265c6 --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/README.md @@ -0,0 +1,86 @@ +## Stable Diffusion Optimization with ONNX Runtime VitisAI EP + +This folder contains sample Olive configurations to optimize **Stable Diffusion v1.5** subgraphs for the **VitisAI Execution Provider** on AMD NPU. + +## Supported models and configs + +| Model ID (Hugging Face) | Config file | +|:---------------------|:------------| +| `sd-legacy/stable-diffusion-v1-5` | `config_unet.json` | +| `sd-legacy/stable-diffusion-v1-5` | `config_vae_decoder.json` | +| `sd-legacy/stable-diffusion-v1-5` | `config_vae_encoder.json` | +| `sd-legacy/stable-diffusion-v1-5` | `config_text_encoder.json` | +| `sd-legacy/stable-diffusion-v1-5` | `config_safety_checker.json` | + +## Run the VitisAI workflow + +#### Create a conda environment and install Olive + +```bash +conda create -n olive python=3.12,pip +conda activate olive +``` + +```bash +git clone https://github.com/microsoft/Olive.git +cd Olive +pip install -e . +``` + +#### Install VitisAI Stable Diffusion dependencies + +```bash +git clone https://github.com/microsoft/olive-recipes.git +cd olive-recipes/sd-legacy-stable-diffusion-v1-5/VitisAI +pip install --force-reinstall -r requirements_vitisai_sd.txt +``` + +## Generate optimized subgraphs (optional) + +Run Olive to generate NPU-ready optimized submodels. + +> **Note:** This step is optional. If you only need the optimized ONNX models for NPU, you can run this step alone without the full pipeline. + +```bash +cd olive-recipes/sd-legacy-stable-diffusion-v1-5/VitisAI + +olive run --config config_unet.json +olive run --config config_vae_decoder.json +olive run --config config_vae_encoder.json +olive run --config config_text_encoder.json +olive run --config config_safety_checker.json +``` + +Optimized artifacts are written to the `output_dir` defined in each JSON (for example `footprints/unet`, `footprints/vae_decoder`, …). + +> **Note:** Exact paths depend on `output_dir` and `cache_dir` in each config file. + +### Execution provider and hardware placement + +| Component | Execution provider | Compute device | +|-----------|-------------------|----------------| +| UNet | VitisAI EP | NPU | +| VAE decoder | VitisAI EP | NPU | +| Text encoder | CPU EP | CPU | +| VAE encoder | CPU EP | CPU | +| Safety checker | CPU EP | CPU | + +The VitisAI Execution Provider is used only for the **UNet** and **VAE decoder**. All other subgraphs run with the **CPU Execution Provider** on the host CPU. + +### End-to-end image generation (inference) + +```bash +cd olive-recipes/sd-legacy-stable-diffusion-v1-5/VitisAI + +python stable_diffusion.py --provider vitisai --model_id sd-legacy/stable-diffusion-v1-5 --seed 0 --guidance_scale 7.5 --num_inference_steps 20 --prompt "Photo of an ultra realistic sailing ship, dramatic light, pale sunrise, cinematic lighting, battered, low angle, trending on artstation, 4k, hyper realistic, focused, extreme details, unreal engine 5, cinematic, masterpiece, art by studio ghibli, intricate artwork by john william turner." +``` + +## Outputs (relative to `VitisAI/`) + +| Item | Location | +|:-----|:---------| +| Generated images | `result_0.png`, `result_1.png`, … in the **current working directory** (typically `VitisAI/` if you run the command from there) | +| Full pipeline, unoptimized | `model/unoptimized//` | +| Full pipeline, optimized (VitisAI) | `model/optimized-vitisai//` | + +`model_id` slashes become nested folders (e.g. `sd-legacy/stable-diffusion-v1-5`). Per-subgraph `olive run` outputs use each config’s `output_dir` / `cache_dir` (e.g. under `footprints/`, `vai_cache/`). diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/assets/result.png b/sd-legacy-stable-diffusion-v1-5/VitisAI/assets/result.png new file mode 100644 index 00000000..b5ff4ba5 Binary files /dev/null and b/sd-legacy-stable-diffusion-v1-5/VitisAI/assets/result.png differ diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/config_safety_checker.json b/sd-legacy-stable-diffusion-v1-5/VitisAI/config_safety_checker.json new file mode 100644 index 00000000..5111351b --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/config_safety_checker.json @@ -0,0 +1,55 @@ +{ + "input_model": { + "type": "PyTorchModel", + "model_path": "sd-legacy/stable-diffusion-v1-5", + "model_loader": "safety_checker_load", + "script_dir": ".", + "model_script": "user_script.py", + "io_config": { + "input_names": [ "clip_input", "images" ], + "output_names": [ "out_images", "has_nsfw_concepts" ], + "dynamic_axes": { + "clip_input": { "0": "batch", "1": "channels", "2": "height", "3": "width" }, + "images": { "0": "batch", "1": "height", "2": "width", "3": "channels" } + } + }, + "dummy_inputs_func": "safety_checker_conversion_inputs" + }, + "passes": { + "convert": { "type": "OnnxConversion", "target_opset": 14 }, + "optimize": { + "type": "OrtTransformersOptimization", + "model_type": "unet", + "opt_level": 0, + "float16": true, + "use_gpu": true, + "keep_io_types": false, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + "enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false, + "group_norm_channels_last": false + }, + "force_fp32_ops": [ "RandomNormalLike" ], + "force_fp16_inputs": { "GroupNorm": [ 0, 1, 2 ] } + } + }, + "log_severity_level": 0, + "cache_dir": "vai_cache", + "output_dir": "footprints/safety_checker" +} diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/config_text_encoder.json b/sd-legacy-stable-diffusion-v1-5/VitisAI/config_text_encoder.json new file mode 100644 index 00000000..24b95df2 --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/config_text_encoder.json @@ -0,0 +1,57 @@ +{ + "input_model": { + "type": "PyTorchModel", + "model_path": "sd-legacy/stable-diffusion-v1-5", + "model_loader": "text_encoder_load", + "script_dir": ".", + "model_script": "user_script.py", + "io_config": { + "input_names": [ "input_ids" ], + "output_names": [ "last_hidden_state", "pooler_output" ], + "dynamic_axes": { "input_ids": { "0": "batch", "1": "sequence" } } + }, + "dummy_inputs_func": "text_encoder_conversion_inputs" + }, + "passes": { + "convert": { "type": "OnnxConversion", "target_opset": 14 }, + "optimize": { + "type": "OrtTransformersOptimization", + "model_type": "clip", + "opt_level": 0, + "float16": true, + "use_gpu": true, + "keep_io_types": false, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + "enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false, + "group_norm_channels_last": false + }, + "force_fp32_ops": [ "RandomNormalLike" ], + "force_fp16_inputs": { "GroupNorm": [ 0, 1, 2 ] } + }, + "dynamic_shape_to_fixed": { + "type": "DynamicToFixedShape", + "dim_param": [ "batch", "sequence" ], + "dim_value": [ 1, 77 ] + } + }, + "log_severity_level": 0, + "cache_dir": "vai_cache", + "output_dir": "footprints/text_encoder" +} diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/config_unet.json b/sd-legacy-stable-diffusion-v1-5/VitisAI/config_unet.json new file mode 100644 index 00000000..7b410d92 --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/config_unet.json @@ -0,0 +1,40 @@ +{ + "input_model": { + "type": "PyTorchModel", + "model_path": "sd-legacy/stable-diffusion-v1-5", + "model_loader": "unet_load", + "script_dir": ".", + "model_script": "user_script.py", + "io_config": { + "input_names": [ "sample", "timestep", "encoder_hidden_states", "return_dict" ], + "output_names": [ "out_sample" ], + "dynamic_axes": { + "sample": { + "0": "batch", + "1": "channels", + "2": "height", + "3": "width" + }, + "encoder_hidden_states": { "0": "batch", "1": "sequence" } + } + }, + "dummy_inputs_func": "unet_conversion_inputs" + }, + "passes": { + "convert": { + "type": "OnnxConversion", + "target_opset": 17, + "save_as_external_data": true, + "all_tensors_to_one_file": true, + "external_data_name": "weights.pb" + }, + "model_generation": { + "type": "VitisGenerateModelSD", + "model_type": "sd-unet", + "resolutions": ["512x512"] + } + }, + "log_severity_level": 0, + "cache_dir": "vai_cache", + "output_dir": "footprints/unet" +} diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/config_vae_decoder.json b/sd-legacy-stable-diffusion-v1-5/VitisAI/config_vae_decoder.json new file mode 100644 index 00000000..46ad2178 --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/config_vae_decoder.json @@ -0,0 +1,33 @@ +{ + "input_model": { + "type": "PyTorchModel", + "model_path": "sd-legacy/stable-diffusion-v1-5", + "model_loader": "vae_decoder_load", + "script_dir": ".", + "model_script": "user_script.py", + "io_config": { + "input_names": [ "latent_sample", "return_dict" ], + "output_names": [ "sample" ], + "dynamic_axes": { + "latent_sample": { + "0": "batch", + "1": "channels", + "2": "height", + "3": "width" + } + } + }, + "dummy_inputs_func": "vae_decoder_conversion_inputs" + }, + "passes": { + "convert": { "type": "OnnxConversion", "target_opset": 17 }, + "model_generation": { + "type": "VitisGenerateModelSD", + "model_type": "sd15-vae-decoder", + "resolutions": ["512x512"] + } + }, + "log_severity_level": 0, + "cache_dir": "vai_cache", + "output_dir": "footprints/vae_decoder" +} diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/config_vae_encoder.json b/sd-legacy-stable-diffusion-v1-5/VitisAI/config_vae_encoder.json new file mode 100644 index 00000000..aa3d197a --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/config_vae_encoder.json @@ -0,0 +1,68 @@ +{ + "input_model": { + "type": "PyTorchModel", + "model_path": "sd-legacy/stable-diffusion-v1-5", + "model_loader": "vae_encoder_load", + "script_dir": ".", + "model_script": "user_script.py", + "io_config": { + "input_names": [ "sample", "return_dict" ], + "output_names": [ "latent_sample" ], + "dynamic_axes": { + "sample": { "0": "encoder_batch", "1": "encoder_channels", "2": "encoder_height", "3": "encoder_width" } + } + }, + "dummy_inputs_func": "vae_encoder_conversion_inputs" + }, + "passes": { + "convert": { "type": "OnnxConversion", "target_opset": 17 }, + "optimize": { + "type": "OrtTransformersOptimization", + "model_type": "vae", + "opt_level": 0, + "float16": true, + "use_gpu": true, + "keep_io_types": false, + "optimization_options": { + "enable_gelu": true, + "enable_layer_norm": true, + "enable_attention": true, + "use_multi_head_attention": true, + "enable_skip_layer_norm": false, + "enable_embed_layer_norm": true, + "enable_bias_skip_layer_norm": false, + "enable_bias_gelu": true, + "enable_gelu_approximation": false, + "enable_qordered_matmul": false, + "enable_shape_inference": true, + "enable_gemm_fast_gelu": false, + "enable_nhwc_conv": false, + "enable_group_norm": true, + "enable_bias_splitgelu": false, + "enable_packed_qkv": true, + "enable_packed_kv": true, + "enable_bias_add": false, + "group_norm_channels_last": false + }, + "force_fp32_ops": [ "RandomNormalLike" ], + "force_fp16_inputs": { "GroupNorm": [ 0, 1, 2 ] } + }, + "dynamic_shape_to_fixed": { + "type": "DynamicToFixedShape", + "dim_param": [ + "encoder_batch", + "encoder_channels", + "encoder_height", + "encoder_width", + "Addlatent_sample_dim_0", + "Addlatent_sample_dim_1", + "Addlatent_sample_dim_2", + "Addlatent_sample_dim_3" + ], + "dim_value": [ 1, 3, 512, 512, 1, 4, 64, 64 ] + } + }, + "log_severity_level": 0, + "cache_dir": "vai_cache", + "output_dir": "footprints/vae_encoder" +} diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/model_adaptations.py b/sd-legacy-stable-diffusion-v1-5/VitisAI/model_adaptations.py new file mode 100644 index 00000000..945237ad --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/model_adaptations.py @@ -0,0 +1,605 @@ +# --------------------------------------------------------------------- +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# --------------------------------------------------------------------- +import math +import types +from typing import Callable, Optional + +import diffusers.models.attention_processor as attention_processor +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers import UNet2DConditionModel +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention import BasicTransformerBlock, FeedForward +from diffusers.models.transformers.transformer_2d import Transformer2DModel + + +class Conv2dLinear(torch.nn.Module): + """A class to convert a Linear layer to a Conv2D layer with a 1x1 kernel. + This allows the linear transformation to be applied to the channel dimension + at each spatial location in the input tensor. + + Args: + linear (nn.Linear): The original linear layer to be converted. + + """ + + def __init__(self, linear: torch.nn.Linear): + super().__init__() + self.in_features = linear.in_features + self.out_features = linear.out_features + + # Initialize a Conv2D layer with a 1x1 kernel to mimic the Linear layer + self.conv = torch.nn.Conv2d( + in_channels=self.in_features, + out_channels=self.out_features, + kernel_size=1, + bias=(linear.bias is not None), + ) + + # Copy the weights from the Linear layer to the Conv2D layer + self.conv.weight.data.copy_(linear.weight.data.view(self.out_features, self.in_features, 1, 1)) + + # Copy the bias if it exists + if linear.bias is not None: + self.conv.bias.data.copy_(linear.bias.data) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward-pass routine for the Conv2D layer. + + Args: + x (torch.Tensor): The input tensor in NCHW format. + + Returns: + torch.Tensor: The output tensor after applying the Conv2D transformation. + + """ + return self.conv(x) + + +class SHAAttention(nn.Module): + """Split-Head Attention with per-head Conv2D projections and a single output + Conv2D projection. This implementation splits the attention heads into + separate Conv2D projection layers and applies a single output projection + after concatenating all heads. Adjusted to handle spatial dimensions (H, + W) instead of sequence length. + """ + + def __init__(self, orig_attn: attention_processor.Attention): + """Initialize SHAAttention by copying weights from an existing Attention module. + + Args: + orig_attn (attention_processor.Attention): The original Attention module to be replaced. + + """ + super().__init__() + + for f in ["group_norm", "spatial_norm", "norm_q", "norm_k", "norm_cross"]: + if getattr(orig_attn, f) is not None: + raise NotImplementedError(f"{f} is not supported") + + # Copy configuration from the original Attention module + self.heads = orig_attn.heads + self.kv_heads = int(orig_attn.inner_kv_dim / orig_attn.inner_dim * self.heads) + + # Infer dim_head from to_q.out_features and heads + if orig_attn.to_q.out_features % self.heads != 0: + raise ValueError("to_q.out_features is not divisible by heads. Cannot infer dim_head.") + self.dim_head = orig_attn.to_q.out_features // self.heads + self.scale = 1 / math.sqrt(self.dim_head) + self.rescale_output_factor_inv = 1 / orig_attn.rescale_output_factor + + self.residual_connection = orig_attn.residual_connection + + # Verify to_k and to_v dimensions + expected_kv_out = self.kv_heads * self.dim_head + if orig_attn.to_k.out_features != expected_kv_out: + raise ValueError( + f"to_k.out_features ({orig_attn.to_k.out_features}) does not match expected {expected_kv_out}." + ) + if orig_attn.to_v.out_features != expected_kv_out: + raise ValueError( + f"to_v.out_features ({orig_attn.to_v.out_features}) does not match expected {expected_kv_out}." + ) + + # Initialize separate Conv2D projection layers for each head + self.q_proj_sha = nn.ModuleList( + [ + nn.Conv2d( + orig_attn.to_q.in_features, + self.dim_head, + kernel_size=1, + bias=(orig_attn.to_q.bias is not None), + ) + for _ in range(self.heads) + ] + ) + self.k_proj_sha = nn.ModuleList( + [ + nn.Conv2d( + orig_attn.to_k.in_features, + self.dim_head, + kernel_size=1, + bias=(orig_attn.to_k.bias is not None), + ) + for _ in range(self.kv_heads) + ] + ) + self.v_proj_sha = nn.ModuleList( + [ + nn.Conv2d( + orig_attn.to_v.in_features, + self.dim_head, + kernel_size=1, + bias=(orig_attn.to_v.bias is not None), + ) + for _ in range(self.kv_heads) + ] + ) + + self.to_out = orig_attn.to_out + + # Copy weights from the original shared Linear projections to the separate Conv2D projections + for i in range(self.heads): + # Query Projection + q_weight = orig_attn.to_q.weight.data[i * self.dim_head : (i + 1) * self.dim_head, :].clone() + q_weight = q_weight.unsqueeze(-1).unsqueeze(-1) # Shape: (dim_head, in_features, 1, 1) + self.q_proj_sha[i].weight.data.copy_(q_weight) + if orig_attn.to_q.bias is not None: + self.q_proj_sha[i].bias.data.copy_( + orig_attn.to_q.bias.data[i * self.dim_head : (i + 1) * self.dim_head].clone() + ) + + for i in range(self.kv_heads): + # Key Projection + k_weight = orig_attn.to_k.weight.data[i * self.dim_head : (i + 1) * self.dim_head, :].clone() + k_weight = k_weight.unsqueeze(-1).unsqueeze(-1) # Shape: (dim_head, in_features, 1, 1) + self.k_proj_sha[i].weight.data.copy_(k_weight) + if orig_attn.to_k.bias is not None: + self.k_proj_sha[i].bias.data.copy_( + orig_attn.to_k.bias.data[i * self.dim_head : (i + 1) * self.dim_head].clone() + ) + + # Value Projection + v_weight = orig_attn.to_v.weight.data[i * self.dim_head : (i + 1) * self.dim_head, :].clone() + v_weight = v_weight.unsqueeze(-1).unsqueeze(-1) # Shape: (dim_head, in_features, 1, 1) + self.v_proj_sha[i].weight.data.copy_(v_weight) + if orig_attn.to_v.bias is not None: + self.v_proj_sha[i].bias.data.copy_( + orig_attn.to_v.bias.data[i * self.dim_head : (i + 1) * self.dim_head].clone() + ) + + del orig_attn.to_q + del orig_attn.to_k + del orig_attn.to_v + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """Forward pass for Split-Head Cross Attention. + Processes each head separately using head-specific Conv2D projection layers. + + Args: + hidden_states (torch.Tensor): The hidden states (batch_size, hidden_size, H, W). + attention_mask (Optional[torch.Tensor]): The attention mask. + encoder_hidden_states (Optional[torch.Tensor]): The encoder hidden states for cross-attention. + **kwargs: Additional keyword arguments. + + Returns: + Tuple containing the attention output, attention weights, and past key-value. + + """ + bsz, hidden_size, H, W = hidden_states.size() + residual = hidden_states + + if encoder_hidden_states is not None: + # (N, seq_len, inner_dim) to (N, inner_dim, 1, seq_len) + encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1).unsqueeze(2) + # encoder_hidden_states: (N, inner_dim, 1, seq_len) + else: + encoder_hidden_states = hidden_states + + query_states = [q_proj(hidden_states) for q_proj in self.q_proj_sha] + key_states = [k_proj(encoder_hidden_states) for k_proj in self.k_proj_sha] + value_states = [v_proj(encoder_hidden_states) for v_proj in self.v_proj_sha] + # query_states, key_states, value_states: List of (bsz, dim_head, H, W) + + # Handle past_key_value for caching + past_key_value = kwargs.get("past_key_value") + if past_key_value is not None: + raise NotImplementedError("SHAAttention does not support kv cache yet") + + # Prepare for attention computation + attn_outputs = [] + for head_idx, (q, k, v) in enumerate(zip(query_states, key_states, value_states)): + q_flat = q.permute(0, 2, 3, 1) # (bsz, H, W, dim_head) + k_flat = k.view(bsz, 1, self.dim_head, -1) # (bsz, 1, dim_head, H_enc*W_enc) + v_flat = v.view(bsz, 1, self.dim_head, -1) # (bsz, 1, dim_head, H_enc*W_enc) + + attn_scores = torch.matmul(q_flat, k_flat) * self.scale + # attn_scores: (bsz, H, W, H_enc*W_enc) + + if attention_mask is not None: + attn_scores = attn_scores + attention_mask + + attn_probs = torch.nn.functional.softmax(attn_scores, dim=-1) + # attn_probs: (bsz, H, W, H_enc*W_enc) + + # Compute attention output + v_perm = v_flat.permute(0, 1, 3, 2) # (bsz, 1, H_enc*W_enc, dim_head) + attn_output = torch.matmul(attn_probs, v_perm) + # attn_output: (bsz, H, W, dim_head) + + attn_outputs.append(attn_output) + + # Concatenate all heads' outputs along the channel dimension + attn_output = torch.cat(attn_outputs, dim=-1) # (bsz, H, W, heads * dim_head) + + attn_output = self.to_out[0](attn_output) # (bsz, H, W, out_features) + attn_output = self.to_out[1](attn_output) # (bsz, H, W, out_features) + attn_output = attn_output.permute(0, 3, 1, 2) # (bsz,out_features, H, W) + + if self.residual_connection: + hidden_states = hidden_states + residual + + if kwargs.get("output_attentions", False): + raise NotImplementedError("output_attentions=True is not supported") + + if self.rescale_output_factor_inv != 1: + hidden_states *= self.rescale_output_factor_inv + + return attn_output + + +class PermuteLayerNorm(nn.Module): + def __init__(self, original_norm): + super().__init__() + self.original_norm = original_norm + + def forward(self, *args, **kwargs): + # Assuming the first argument is the tensor to be normalized + # Permute the tensor dimensions from (N, C, H, W) to (N, H, W, C) + permuted_args = list(args) + if len(permuted_args) > 0 and isinstance(permuted_args[0], torch.Tensor): + permuted_args[0] = permuted_args[0].permute(0, 2, 3, 1) + + # Apply the original normalization + norm_output = self.original_norm(*permuted_args, **kwargs) + + # If the output is a tuple (as in some custom norms), permute relevant tensors + if isinstance(norm_output, tuple): + # Permute the first tensor in the output tuple + norm_output = (norm_output[0].permute(0, 3, 1, 2),) + norm_output[1:] + elif isinstance(norm_output, torch.Tensor): + norm_output = norm_output.permute(0, 3, 1, 2) + + return norm_output + + +def traverse_and_replace( + model: nn.Module, + target_type: type[torch.nn.Module], + replacement_fn: Callable[[torch.nn.Module], torch.nn.Module], +): + """Recursively traverses the model to find and replace modules of a specified type. + + Args: + model (nn.Module): The model to traverse. + target_type (type): The type of modules to replace (e.g., Attention, GELU). + replacement_fn (callable): A function that takes a module instance and returns the replacement module. + + """ + for name, module in model.named_children(): + if isinstance(module, target_type): + setattr(model, name, replacement_fn(module)) + + elif isinstance(module, nn.ModuleList): + for idx in range(len(module)): + child = module[idx] + if isinstance(child, target_type): + module[idx] = replacement_fn(child) + else: + # Recursively apply to child modules + traverse_and_replace(child, target_type, replacement_fn) + else: + traverse_and_replace(module, target_type, replacement_fn) + + +def replace_attention_modules(model: nn.Module): + """Recursively traverses the model to find and replace all instances of Attention with SHAAttention, + including those nested within ModuleList containers. + + Args: + model (nn.Module): The model in which to replace Attention modules. + + """ + traverse_and_replace(model, attention_processor.Attention, lambda orig_attn: SHAAttention(orig_attn)) + + +def replace_gelu_and_approx_gelu_with_conv2d(activation_module: nn.Module) -> nn.Module: + """Replaces the projection layer in GELU and ApproximateGELU activation modules from Linear to Conv2D. + + Args: + activation_module (nn.Module): The activation module to replace. + + Returns: + nn.Module: The activation module with Conv2D projection. + + """ + assert isinstance(activation_module, GELU) or isinstance(activation_module, ApproximateGELU) + dim_in = activation_module.proj.in_features + dim_out = activation_module.proj.out_features + bias = activation_module.proj.bias is not None + + # Define Conv2d projection + conv = nn.Conv2d(in_channels=dim_in, out_channels=dim_out, kernel_size=1, bias=bias) + + # Copy weights from Linear to Conv2d + with torch.no_grad(): + conv.weight.copy_(activation_module.proj.weight.view(dim_out, dim_in, 1, 1)) + if bias: + conv.bias.copy_(activation_module.proj.bias) + + # Replace the Linear layer with Conv2d + activation_module.proj = conv + return activation_module + + +class QcGEGLU(nn.Module): + r"""A reimplemented version of the GEGLU activation function using two Conv2D layers. + This class replaces the original GEGLU's Linear projections with Conv2D projections + and eliminates the need for the chunk operation by directly computing the gate. + + Parameters: + original_geglu (GEGLU): The original GEGLU module to be replaced. + + """ + + def __init__(self, original_geglu: GEGLU): + super().__init__() + # Extract dimensions from the original GEGLU + dim_in = original_geglu.proj.in_features + dim_out = original_geglu.proj.out_features // 2 # GEGLU splits output into two parts + bias = original_geglu.proj.bias is not None + + # Define separate Conv2D layers for hidden projection and gate projection + self.hidden_proj = nn.Conv2d(in_channels=dim_in, out_channels=dim_out, kernel_size=1, bias=bias) + self.gate_proj = nn.Conv2d(in_channels=dim_in, out_channels=dim_out, kernel_size=1, bias=bias) + + # Initialize weights and biases from the original GEGLU's Linear layer + with torch.no_grad(): + # Original Linear weights shape: [dim_out*2, dim_in] + linear_weight = original_geglu.proj.weight.data # Shape: [dim_out*2, dim_in] + linear_bias = original_geglu.proj.bias.data if bias else None # Shape: [dim_out*2] + + # Assign weights to hidden_proj and gate_proj Conv2D layers + self.hidden_proj.weight.copy_(linear_weight[:dim_out, :].view(dim_out, dim_in, 1, 1)) + if bias: + self.hidden_proj.bias.copy_(linear_bias[:dim_out]) # type: ignore + + self.gate_proj.weight.copy_(linear_weight[dim_out:, :].view(dim_out, dim_in, 1, 1)) + if bias: + self.gate_proj.bias.copy_(linear_bias[dim_out:]) # type: ignore + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + return F.gelu(gate) + + def forward(self, hidden_states, *args, **kwargs): + # Project hidden states and compute gate + hidden_proj = self.hidden_proj(hidden_states) # (N, dim_out, H, W) + gate = self.gate_proj(hidden_states) # (N, dim_out, H, W) + + # Apply GELU activation to the gate + gate = self.gelu(gate) + + # Apply gating mechanism + return hidden_proj * gate + + +def replace_geglu_with_conv2d(activation_module: nn.Module) -> nn.Module: + """Replaces the original GEGLU activation module with the QcGEGLU module, + which uses two Conv2D layers and eliminates the chunk operation. + + Args: + activation_module (nn.Module): The GEGLU activation module to replace. + + Returns: + nn.Module: The QcGEGLU activation module with two Conv2D projections. + + """ + if isinstance(activation_module, GEGLU): + # Instantiate QcGEGLU with the original GEGLU module + qc_geglu = QcGEGLU(activation_module) + return qc_geglu + else: + raise TypeError(f"Unsupported activation module type for GEGLU replacement: {type(activation_module)}") + + +def replace_activations_with_conv2d(model: nn.Module): + """Recursively traverses the model to find and replace GELU, GEGLU, and ApproximateGELU activation projections + from Linear layers to Conv2D layers, ensuring compatibility with NCHW input shapes. + Also handles activations nested within ModuleList containers. + + Args: + model (nn.Module): The model in which to perform the replacement. + + """ + # Replace GELU and ApproximateGELU + traverse_and_replace(model, GELU, replace_gelu_and_approx_gelu_with_conv2d) + traverse_and_replace(model, ApproximateGELU, replace_gelu_and_approx_gelu_with_conv2d) + + # Replace GEGLU + traverse_and_replace(model, GEGLU, replace_geglu_with_conv2d) + + +def replace_feedforward_with_conv2d(feedforward_module: nn.Module) -> nn.Module: + """Replaces the nn.Linear layer in the FeedForward module with a Conv2D layer + to handle hidden_states of shape (N, C, H, W). + + Args: + feedforward_module (nn.Module): The FeedForward module to replace. + + Returns: + nn.Module: The FeedForward module with Conv2D layers instead of Linear layers. + + """ + if isinstance(feedforward_module, FeedForward): + # Create a new ModuleList to hold the modified layers + new_net = nn.ModuleList() + for module in feedforward_module.net: + if isinstance(module, nn.Linear): + # Define Conv2d projection + conv = nn.Conv2d( + in_channels=module.in_features, + out_channels=module.out_features, + kernel_size=1, + bias=module.bias is not None, + ) + # Copy weights from Linear to Conv2d + with torch.no_grad(): + conv.weight.copy_(module.weight.data.view(module.out_features, module.in_features, 1, 1)) + if module.bias is not None: + conv.bias.copy_(module.bias.data) + # Append the Conv2d layer instead of Linear + new_net.append(conv) + else: + # Append other modules (e.g., activation functions, Dropout) unchanged + new_net.append(module) + # Replace the original ModuleList with the new one containing Conv2d layers + feedforward_module.net = new_net + return feedforward_module + else: + raise TypeError(f"Unsupported module type for FeedForward replacement: {type(feedforward_module)}") + + +def replace_feedforward_modules(model: nn.Module): + # Replace FeedForward modules' Linear layers with Conv2D + traverse_and_replace(model, FeedForward, replace_feedforward_with_conv2d) + + +def replace_layer_norm_modules(model: nn.Module): + """Recursively traverses the model to find and replace all instances of + LayerNorm within BasicTransformerBlock with PermuteLayerNorm to be + compatible with optimized_operate_on_continuous_inputs. + + Args: + model (nn.Module): The model in which to replace LayerNorm modules. + + """ + + def replace_layer_norm(block: BasicTransformerBlock) -> BasicTransformerBlock: + """Replaces norm1, norm2, and norm3 within a BasicTransformerBlock. + + Args: + block (BasicTransformerBlock): The transformer block to modify. + + Returns: + BasicTransformerBlock: The modified transformer block. + + """ + block.norm1 = PermuteLayerNorm(block.norm1) + block.norm2 = PermuteLayerNorm(block.norm2) + if hasattr(block, "norm3"): + block.norm3 = PermuteLayerNorm(block.norm3) + return block + + traverse_and_replace(model, BasicTransformerBlock, replace_layer_norm) + + +def replace_transformer2d_modules(model: nn.Module): + """Recursively traverses the model to find and replace all instances of + Transformer2DModel that use linear projection, patching only those + instances with the optimized continuous‐input methods and swapping in + Conv2dLinear for proj_in / proj_out. + """ + + def _patch_instance(m: Transformer2DModel): + # bind the two optimized routines onto this instance only + m._operate_on_continuous_inputs = types.MethodType( # type: ignore + optimized_operate_on_continuous_inputs, m + ) + m._get_output_for_continuous_inputs = types.MethodType( # type: ignore + optimized_get_output_for_continuous_inputs, m + ) + + def _replace_transformer2d(m: Transformer2DModel) -> Transformer2DModel: + # 1) patch the instance methods + _patch_instance(m) + # 2) swap out the Linear‐based proj_in / proj_out for Conv2dLinear + if isinstance(m.proj_in, nn.Linear): + m.proj_in = Conv2dLinear(m.proj_in) + if isinstance(m.proj_out, nn.Linear): + m.proj_out = Conv2dLinear(m.proj_out) + return m + + traverse_and_replace(model, Transformer2DModel, _replace_transformer2d) + + +def optimized_operate_on_continuous_inputs(self, hidden_states): + """By using 4D NCHW hidden states, we can skip permutation and reshape + required in the HF implementation. + """ + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + else: + inner_dim = hidden_states.shape[1] + hidden_states = self.proj_in(hidden_states) + return hidden_states, inner_dim + + +def optimized_get_output_for_continuous_inputs(self, hidden_states, residual, batch_size, height, width, inner_dim): + """Similar to optimized_operate_on_continuous_inputs""" + hidden_states = self.proj_out(hidden_states) + return hidden_states + residual + + +def get_timestep_embedding(sample: torch.Tensor, timestep: torch.Tensor): + """Adapted from diffusers.models.get_timestep_embedding. + Removes parameters unused by our implementation and supports batching. + """ + embedding_dim = 320 # TODO: Extract from last unet layers + MAX_PERIOD = 10000 + half_dim = embedding_dim // 2 + exponent = -math.log(MAX_PERIOD) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timestep.device) + exponent = exponent / half_dim + + emb = torch.exp(exponent) + emb = timestep.float() * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + + return emb + + +def monkey_patch_model(model: UNet2DConditionModel): + """1. Apply monkey patches + 2. Apply module replacements for targeted modules whenever monkeypatch + code is too long + + Note: + - This monkeypatch is verified against diffusers==0.31.0 on stable + diffusion 1.5 + + """ + print("Monkeypatching Unet (replacing MHA with SHA attention etc)") + replace_attention_modules(model) + replace_activations_with_conv2d(model) + replace_layer_norm_modules(model) + replace_feedforward_modules(model) + replace_transformer2d_modules(model) diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/requirements_vitisai_sd.txt b/sd-legacy-stable-diffusion-v1-5/VitisAI/requirements_vitisai_sd.txt new file mode 100644 index 00000000..f4b3ab6b --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/requirements_vitisai_sd.txt @@ -0,0 +1,27 @@ +# AMD model generation +--extra-index-url=https://pypi.amd.com/ryzenai_llm/1.7.1/windows/simple + +accelerate +datasets + +# Pin onnx version +diffusers==0.35.0 +evaluate +# opentelemetry-api requires importlib-metadata<8.8.0 +importlib-metadata>=6.0,<8.8.0 + +model-generate==1.7.1 +nltk +numpy==1.26.4 +onnx==1.20.1 +onnxruntime-vitisai==1.24.3 +optimum +pandas==2.2.2 +psutil==6.0.0 +ryzenai-dynamic-dispatch==1.7.1 +ryzenai-onnx-utils==1.7.1 +sentencepiece +tabulate +torch==2.6.0 +transformers==4.56.2 +voe==1.7.1 diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/sd_utils/config.py b/sd-legacy-stable-diffusion-v1-5/VitisAI/sd_utils/config.py new file mode 100644 index 00000000..98801ebf --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/sd_utils/config.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +vae_sample_size = 512 +unet_sample_size = 64 +cross_attention_dim = 768 +only_conversion = True +data_dir = "quantize_data" diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/sd_utils/ort.py b/sd-legacy-stable-diffusion-v1-5/VitisAI/sd_utils/ort.py new file mode 100644 index 00000000..da36cb5d --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/sd_utils/ort.py @@ -0,0 +1,189 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import json +import shutil +import sys +from pathlib import Path + +import onnxruntime as ort +from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline +from olive.model import ONNXModelHandler +from onnxruntime import __version__ as OrtVersion +from packaging import version +from sd_utils import config as sd_config + +# ruff: noqa: TID252, T201 + + +def update_cuda_config(config_cuda: dict): + if version.parse(OrtVersion) < version.parse("1.17.0"): + # disable skip_group_norm fusion since there is a shape inference bug which leads to invalid models + config_cuda["passes"]["optimize_cuda"]["optimization_options"] = {"enable_skip_group_norm": False} + used_passes = {"convert", "optimize_cuda"} + for pass_name in set(config_cuda["passes"].keys()): + if pass_name not in used_passes: + config_cuda["passes"].pop(pass_name, None) + config_cuda["systems"]["local_system"]["accelerators"][0]["execution_providers"] = ["CUDAExecutionProvider"] + return config_cuda + + +def validate_args(args, provider): + ort.set_default_logger_severity(4) + if args.static_dims: + print( + "WARNING: the --static_dims option is deprecated, and static shape optimization is enabled by default. " + "Use --dynamic_dims to disable static shape optimization." + ) + + validate_ort_version(provider) + + +def validate_ort_version(provider: str): + if provider == "cuda" and version.parse(OrtVersion) < version.parse("1.17.0"): + if version.parse(OrtVersion) < version.parse("1.16.2"): + print("This script requires onnxruntime-gpu 1.16.2 or newer") + sys.exit(1) + print( + f"WARNING: onnxruntime {OrtVersion} has known issues with shape inference for SkipGroupNorm. Will disable" + " skip_group_norm fusion. onnxruntime-gpu 1.17.0 or newer is strongly recommended!" + ) + + +def save_optimized_onnx_submodel(submodel_name, provider, model_info): + footprints_file_path = Path(__file__).resolve().parents[1] / "footprints" / submodel_name / "footprint.json" + with footprints_file_path.open("r") as footprint_file: + footprints = json.load(footprint_file) + + conversion_footprint = None + optimizer_footprint = None + for footprint in footprints.values(): + from_pass = footprint["from_pass"].lower() if footprint["from_pass"] else "" + if from_pass == "OnnxConversion".lower(): + conversion_footprint = footprint + if sd_config.only_conversion: + optimizer_footprint = footprint + elif ( + from_pass == "OrtTransformersOptimization".lower() + or from_pass == "OnnxStaticQuantization".lower() + or from_pass == "EPContextBinaryGenerator".lower() + or from_pass == "DynamicToFixedShape".lower() + or from_pass == "VitisGenerateModelSD".lower() + ): + optimizer_footprint = footprint + + assert conversion_footprint + assert optimizer_footprint + + def _footprint_model_config(node): + # Olive serializes FootprintNode with alias "model_config"; accept both keys + cfg = node.get("model_config_data") or node.get("model_config") + if not cfg: + raise KeyError("Footprint node missing model_config_data / model_config") + return cfg["config"] + + unoptimized_olive_model = ONNXModelHandler(**_footprint_model_config(conversion_footprint)) + optimized_olive_model = ONNXModelHandler(**_footprint_model_config(optimizer_footprint)) + + model_info[submodel_name] = { + "unoptimized": { + "path": Path(unoptimized_olive_model.model_path), + }, + "optimized": { + "path": Path(optimized_olive_model.model_path), + }, + } + + print(f"Unoptimized Model : {model_info[submodel_name]['unoptimized']['path']}") + print(f"Optimized Model : {model_info[submodel_name]['optimized']['path']}") + + +def save_onnx_pipeline( + has_safety_checker, model_info, optimized_model_dir, unoptimized_model_dir, pipeline, submodel_names +): + # Save the unoptimized models in a directory structure that the diffusers library can load and run. + # This is optional, and the optimized models can be used directly in a custom pipeline if desired. + print("\nCreating ONNX pipeline...") + + if has_safety_checker: + safety_checker = OnnxRuntimeModel.from_pretrained(model_info["safety_checker"]["unoptimized"]["path"].parent) + else: + safety_checker = None + + onnx_pipeline = OnnxStableDiffusionPipeline( + vae_encoder=OnnxRuntimeModel.from_pretrained(model_info["vae_encoder"]["unoptimized"]["path"].parent), + vae_decoder=OnnxRuntimeModel.from_pretrained(model_info["vae_decoder"]["unoptimized"]["path"].parent), + text_encoder=OnnxRuntimeModel.from_pretrained(model_info["text_encoder"]["unoptimized"]["path"].parent), + tokenizer=pipeline.tokenizer, + unet=OnnxRuntimeModel.from_pretrained(model_info["unet"]["unoptimized"]["path"].parent), + scheduler=pipeline.scheduler, + safety_checker=safety_checker, + feature_extractor=pipeline.feature_extractor, + requires_safety_checker=True, + ) + + print("Saving unoptimized models...") + onnx_pipeline.save_pretrained(unoptimized_model_dir) + + # Create a copy of the unoptimized model directory, then overwrite with optimized models from the olive cache. + print("Copying optimized models...") + shutil.copytree(unoptimized_model_dir, optimized_model_dir, ignore=shutil.ignore_patterns("weights.pb")) + for submodel_name in submodel_names: + src_path = model_info[submodel_name]["optimized"]["path"] + dst_path = optimized_model_dir / submodel_name / "model.onnx" + shutil.copyfile(src_path, dst_path) + + # Copy the QNN context bin if present + src_ctx_path = Path(str(src_path).replace(".onnx", "_qnn.bin")) + if src_ctx_path.exists(): + dst_ctx_path = optimized_model_dir / submodel_name / "model_ctx_qnn.bin" + shutil.copyfile(src_ctx_path, dst_ctx_path) + + print(f"The optimized pipeline is located here: {optimized_model_dir}") + + +def get_ort_pipeline(model_dir, common_args, ort_args, guidance_scale): + ort.set_default_logger_severity(3) + + print("Loading models into ORT session...") + sess_options = ort.SessionOptions() + sess_options.enable_mem_pattern = False + + static_dims = not ort_args.dynamic_dims + batch_size = common_args.batch_size + image_size = common_args.image_size + provider = common_args.provider + vae_sample_size = sd_config.vae_sample_size + unet_sample_size = sd_config.unet_sample_size + + if static_dims: + hidden_batch_size = batch_size if (guidance_scale <= 1.0) else batch_size * 2 + # batch_size is doubled for sample & hidden state because of classifier free guidance: + # https://github.com/huggingface/diffusers/blob/46c52f9b9607e6ecb29c782c052aea313e6487b7/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py#L672 + sess_options.add_free_dimension_override_by_name("unet_sample_batch", hidden_batch_size) + sess_options.add_free_dimension_override_by_name("unet_sample_channels", 4) + sess_options.add_free_dimension_override_by_name("unet_sample_height", image_size // 8) + sess_options.add_free_dimension_override_by_name("unet_sample_width", image_size // 8) + sess_options.add_free_dimension_override_by_name("unet_time_batch", 1) + sess_options.add_free_dimension_override_by_name("unet_hidden_batch", hidden_batch_size) + sess_options.add_free_dimension_override_by_name("unet_hidden_sequence", 77) + + sess_options.add_free_dimension_override_by_name("decoder_batch", batch_size) + sess_options.add_free_dimension_override_by_name("decoder_channels", 4) + sess_options.add_free_dimension_override_by_name("decoder_height", unet_sample_size) + sess_options.add_free_dimension_override_by_name("decoder_width", unet_sample_size) + + sess_options.add_free_dimension_override_by_name("encoder_batch", batch_size) + sess_options.add_free_dimension_override_by_name("encoder_channels", 3) + sess_options.add_free_dimension_override_by_name("encoder_height", vae_sample_size) + sess_options.add_free_dimension_override_by_name("encoder_width", vae_sample_size) + + provider_map = { + "cuda": "CUDAExecutionProvider", + } + assert provider in provider_map, f"Unsupported provider: {provider}" + return OnnxStableDiffusionPipeline.from_pretrained( + model_dir, provider=provider_map[provider], sess_options=sess_options + ) diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/sd_utils/vai.py b/sd-legacy-stable-diffusion-v1-5/VitisAI/sd_utils/vai.py new file mode 100644 index 00000000..6ce29f8a --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/sd_utils/vai.py @@ -0,0 +1,227 @@ +# ------------------------------------------------------------------------- +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT +# ------------------------------------------------------------------------- +""" +NPU (VitisAI) pipeline save and load: save_npu_pipeline, get_vai_pipeline. +Unet/vae_decoder use DD session options (dd_cache, onnx_custom_ops_const_key); other submodels use OnnxRuntimeModel.from_pretrained. +""" + +import importlib +import json +import shutil +from pathlib import Path +import sd_utils.config +import os + +import onnxruntime as ort +from diffusers import OnnxRuntimeModel, OnnxStableDiffusionPipeline +from transformers import CLIPTokenizer + +def set_dd_plugins_root() -> None: + """Point DD_PLUGINS_ROOT to ryzenai_dynamic_dispatch/bin/ if not already set. + + """ + if os.environ.get("DD_PLUGINS_ROOT"): + return + try: + import importlib.util + + spec = importlib.util.find_spec("ryzenai_dynamic_dispatch") + if spec and spec.origin: + bin_dir = os.path.join(os.path.dirname(spec.origin), "bin") + if os.path.isdir(bin_dir): + os.environ["DD_PLUGINS_ROOT"] = bin_dir + except Exception: + print("Could not set DD_PLUGINS_ROOT: %s", e) + pass + + +# ruff: noqa: TID252, T201 +def update_vai_config(config: dict, provider: str, submodel_name: str): + if provider != "vitisai": + raise ValueError(f"Unsupported provider: {provider}. Only vitisai is supported.") + + used_passes = {} + if sd_utils.config.only_conversion: + used_passes = {"convert"} + config["evaluator"] = None + + if submodel_name in ("text_encoder", "vae_encoder", "safety_checker"): + used_passes = {"convert", "optimize", "dynamic_shape_to_fixed"} + elif submodel_name in ("unet", "vae_decoder"): + used_passes = {"convert", "model_generation"} + + for pass_name in set(config["passes"].keys()): + if pass_name not in used_passes: + config["passes"].pop(pass_name, None) + + config["systems"] = { + "local_system": { + "type": "LocalSystem", + "config": { + "accelerators": [ + {"execution_providers": ["CPUExecutionProvider"], "device": "cpu"}, + ], + }, + }, + } + return config + + +def save_vai_pipeline( + has_safety_checker, model_info, optimized_model_dir, unoptimized_model_dir, pipeline, submodel_names +): + """Save VitisAI pipeline: unet/vae_decoder as full dir (dd/, cache/); others as model.onnx.""" + print("\nCreating ONNX pipeline (VitisAI)...") + + # diffusers >= 0.30 expects provider_options in kwargs (kwargs.pop without default). + _cpu_onnx_kw = {"providers": ["CPUExecutionProvider"], "provider_options": [{}]} + + if has_safety_checker: + safety_checker = OnnxRuntimeModel.from_pretrained( + str(model_info["safety_checker"]["unoptimized"]["path"].parent), **_cpu_onnx_kw + ) + else: + safety_checker = None + + onnx_pipeline = OnnxStableDiffusionPipeline( + vae_encoder=OnnxRuntimeModel.from_pretrained( + str(model_info["vae_encoder"]["unoptimized"]["path"].parent), **_cpu_onnx_kw + ), + vae_decoder=OnnxRuntimeModel.from_pretrained( + str(model_info["vae_decoder"]["unoptimized"]["path"].parent), **_cpu_onnx_kw + ), + text_encoder=OnnxRuntimeModel.from_pretrained( + str(model_info["text_encoder"]["unoptimized"]["path"].parent), **_cpu_onnx_kw + ), + tokenizer=pipeline.tokenizer, + unet=OnnxRuntimeModel.from_pretrained(str(model_info["unet"]["unoptimized"]["path"].parent), **_cpu_onnx_kw), + scheduler=pipeline.scheduler, + safety_checker=safety_checker, + feature_extractor=pipeline.feature_extractor, + requires_safety_checker=True, + ) + + print("Saving unoptimized models...") + onnx_pipeline.save_pretrained(unoptimized_model_dir) + + print("Copying optimized models (VitisAI: full dir for unet/vae_decoder)...") + shutil.copytree(unoptimized_model_dir, optimized_model_dir, ignore=shutil.ignore_patterns("weights.pb")) + + NPU_SUBMODELS = ("unet", "vae_decoder") + for submodel_name in submodel_names: + src = Path(model_info[submodel_name]["optimized"]["path"]) + dst_subdir = optimized_model_dir / submodel_name + + if submodel_name in NPU_SUBMODELS: + src_dir = src.parent if src.suffix == ".onnx" else src + if not src_dir.is_dir(): + src_dir = src.parent + shutil.rmtree(dst_subdir, ignore_errors=True) + dst_subdir.mkdir(parents=True, exist_ok=True) + for item in src_dir.iterdir(): + dest_item = dst_subdir / item.name + if item.is_dir(): + shutil.copytree(item, dest_item, dirs_exist_ok=True) + else: + shutil.copy2(item, dest_item) + else: + if src.is_file() and src.suffix == ".onnx": + src_path = src + else: + src_path = (src if src.is_dir() else src.parent) / "model.onnx" + if not src_path.exists(): + raise FileNotFoundError(f"Optimized model not found: {src_path}") + shutil.copyfile(src_path, dst_subdir / "model.onnx") + + src_ctx_path = Path(str(src_path).replace(".onnx", "_qnn.bin")) + if src_ctx_path.exists(): + shutil.copyfile(src_ctx_path, dst_subdir / "model_ctx_qnn.bin") + + print(f"The optimized NPU pipeline is located here: {optimized_model_dir}") + + +def _load_npu_model(model_dir, submodel_name): + """Load unet or vae_decoder from model_dir//dd/replaced.onnx using VitisAI EP.""" + model_dir = Path(model_dir) + replaced_onnx_path = model_dir / submodel_name / "dd" / "replaced.onnx" + if not replaced_onnx_path.exists(): + raise FileNotFoundError(f"NPU optimized model not found: {replaced_onnx_path}") + sess_opts = ort.SessionOptions() + cache_dir = (replaced_onnx_path.parent / "cache").as_posix() + sess_opts.add_session_config_entry("dd_cache", cache_dir) + sess_opts.add_provider("VitisAIExecutionProvider", {}) + provider_options = [{"target": "SD"}] + session = ort.InferenceSession( + str(replaced_onnx_path), sess_options=sess_opts, providers=["VitisAIExecutionProvider"], provider_options=provider_options + ) + model = OnnxRuntimeModel(session) + config_path = model_dir / submodel_name / "config.json" + if config_path.exists(): + with config_path.open() as f: + model.config = json.load(f) + else: + model.config = {} + return model + + +def get_vai_pipeline(model_dir, common_args): + """Build pipeline for VitisAI: unet/vae_decoder from dd/replaced.onnx; text_encoder/vae_encoder on CPU.""" + set_dd_plugins_root() + ort.set_default_logger_severity(3) + model_dir = Path(model_dir) + provider = common_args.provider + if provider != "vitisai": + raise ValueError(f"Unsupported provider: {provider}. Only vitisai is supported.") + + unet = _load_npu_model(model_dir, "unet") + vae_decoder = _load_npu_model(model_dir, "vae_decoder") + print("Loading NPU pipeline (unet/vae_decoder from dd/replaced.onnx with VitisAI EP, text_encoder/vae_encoder with CPU EP)...") + + text_encoder_dir = model_dir / "text_encoder" + vae_encoder_dir = model_dir / "vae_encoder" + if not text_encoder_dir.exists() or not vae_encoder_dir.exists(): + raise FileNotFoundError(f"Missing text_encoder or vae_encoder under {model_dir}") + + # Must pass provider_options so diffusers' from_pretrained (kwargs.pop) does not raise KeyError; values must be dicts. + text_encoder = OnnxRuntimeModel.from_pretrained( + str(text_encoder_dir), providers=["CPUExecutionProvider"], provider_options=[{}] + ) + vae_encoder = OnnxRuntimeModel.from_pretrained( + str(vae_encoder_dir), providers=["CPUExecutionProvider"], provider_options=[{}] + ) + tokenizer = CLIPTokenizer.from_pretrained(str(model_dir / "tokenizer")) + + with (model_dir / "scheduler" / "scheduler_config.json").open() as f: + scheduler_name = json.load(f).get("_class_name", "PNDMScheduler") + scheduler_cls = getattr(importlib.import_module("diffusers.schedulers"), scheduler_name) + scheduler = scheduler_cls.from_pretrained(str(model_dir / "scheduler")) + + feature_extractor = None + if (model_dir / "feature_extractor").exists(): + from transformers import CLIPImageProcessor + feature_extractor = CLIPImageProcessor.from_pretrained(str(model_dir / "feature_extractor")) + + safety_checker = None + if (model_dir / "safety_checker").exists(): + try: + safety_checker = OnnxRuntimeModel.from_pretrained( + str(model_dir / "safety_checker"), + providers=["CPUExecutionProvider"], + provider_options=[{}] + ) + except Exception: + pass + + return OnnxStableDiffusionPipeline( + vae_encoder=vae_encoder, + vae_decoder=vae_decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + requires_safety_checker=False, + ) diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/stable_diffusion.py b/sd-legacy-stable-diffusion-v1-5/VitisAI/stable_diffusion.py new file mode 100644 index 00000000..cc317fee --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/stable_diffusion.py @@ -0,0 +1,408 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import argparse +import json +import shutil +import warnings +from pathlib import Path + +import numpy as np +import torch +from diffusers import DiffusionPipeline +from olive.common.utils import set_tempdir +from olive.workflows import run as olive_run +from sd_utils import config +from user_script import get_base_model_name + +# pylint: disable=redefined-outer-name +# ruff: noqa: TID252, T201 + + +def save_image(result, batch_size, provider, num_images, images_saved, image_callback=None): + passed_safety_checker = 0 + for image_index in range(batch_size): + if result.nsfw_content_detected is None or not result.nsfw_content_detected[image_index]: + passed_safety_checker += 1 + if images_saved < num_images: + output_path = f"result_{images_saved}.png" + result.images[image_index].save(output_path) + if image_callback: + image_callback(images_saved, output_path) + images_saved += 1 + print(f"Generated {output_path}") + print(f"Inference Batch End ({passed_safety_checker}/{batch_size} images).") + print("Images passed the safety checker.") + return images_saved + + +def run_inference_loop( + pipeline, + prompt, + num_images, + batch_size, + image_size, + num_inference_steps, + guidance_scale, + strength: float, + provider: str, + generator=None, + image_callback=None, + step_callback=None, +): + images_saved = 0 + + def update_steps(step, timestep, latents): + if step_callback: + step_callback((images_saved // batch_size) * num_inference_steps + step) + + while images_saved < num_images: + print(f"\nInference Batch Start (batch size = {batch_size}).") + + result = pipeline( + [prompt] * batch_size, + num_inference_steps=num_inference_steps, + callback=update_steps if step_callback else None, + height=image_size, + width=image_size, + guidance_scale=guidance_scale, + generator=generator, + ) + + images_saved = save_image(result, batch_size, provider, num_images, images_saved, image_callback) + + +def run_inference_gui( + pipeline, + prompt, + num_images, + batch_size, + image_size, + num_inference_steps, + guidance_scale, + strength, + provider, + generator, +): + import threading + import tkinter as tk + import tkinter.ttk as ttk + + from PIL import Image, ImageTk + + def update_progress_bar(total_steps_completed): + progress_bar["value"] = total_steps_completed + + def image_completed(index, path): + img = Image.open(path) + photo = ImageTk.PhotoImage(img) + gui_images[index].config(image=photo) + gui_images[index].image = photo + if index == num_images - 1: + generate_button["state"] = "normal" + + def on_generate_click(): + generate_button["state"] = "disabled" + progress_bar["value"] = 0 + threading.Thread( + target=run_inference_loop, + args=( + pipeline, + prompt_textbox.get(), + num_images, + batch_size, + image_size, + num_inference_steps, + guidance_scale, + strength, + provider, + generator, + image_completed, + update_progress_bar, + ), + ).start() + + if num_images > 9: + print("WARNING: interactive UI only supports displaying up to 9 images") + num_images = 9 + + image_rows = 1 + (num_images - 1) // 3 + image_cols = 2 if num_images == 4 else min(num_images, 3) + min_batches_required = 1 + (num_images - 1) // batch_size + + bar_height = 10 + button_width = 80 + button_height = 30 + padding = 2 + window_width = image_cols * image_size + (image_cols + 1) * padding + window_height = image_rows * image_size + (image_rows + 1) * padding + bar_height + button_height + + window = tk.Tk() + window.title("Stable Diffusion") + window.resizable(width=False, height=False) + window.geometry(f"{window_width}x{window_height}") + + gui_images = [] + for row in range(image_rows): + for col in range(image_cols): + label = tk.Label(window, width=image_size, height=image_size, background="black") + gui_images.append(label) + label.place(x=col * image_size, y=row * image_size) + + y = image_rows * image_size + (image_rows + 1) * padding + + progress_bar = ttk.Progressbar(window, value=0, maximum=num_inference_steps * min_batches_required) + progress_bar.place(x=0, y=y, height=bar_height, width=window_width) + + y += bar_height + + prompt_textbox = tk.Entry(window) + prompt_textbox.insert(tk.END, prompt) + prompt_textbox.place(x=0, y=y, width=window_width - button_width, height=button_height) + + generate_button = tk.Button(window, text="Generate", command=on_generate_click) + generate_button.place(x=window_width - button_width, y=y, width=button_width, height=button_height) + + window.mainloop() + + +def update_config_with_provider(config: dict, provider: str, submodel_name: str): + if provider == "vitisai": + from sd_utils.vai import update_vai_config + + return update_vai_config(config, provider, submodel_name) + elif provider == "cuda": + from sd_utils.ort import update_cuda_config + + return update_cuda_config(config) + elif provider == "cpu": + return config + else: + raise ValueError(f"Unsupported provider: {provider}") + + +def optimize( + common_args, + unoptimized_model_dir: Path, + optimized_model_dir: Path, +): + model_id = common_args.model_id + provider = common_args.provider + + script_dir = Path(__file__).resolve().parent + + # Clean up previously optimized models, if any. + shutil.rmtree(script_dir / "footprints", ignore_errors=True) + shutil.rmtree(unoptimized_model_dir, ignore_errors=True) + shutil.rmtree(optimized_model_dir, ignore_errors=True) + + # The model_id and base_model_id are identical when optimizing a standard stable diffusion model like + # CompVis/stable-diffusion-v1-4. These variables are only different when optimizing a LoRA variant. + base_model_id = get_base_model_name(model_id) + print(f"\nModel {base_model_id}") + + # Load the entire PyTorch pipeline to ensure all models and their configurations are downloaded and cached. + # This avoids an issue where the non-ONNX components (tokenizer, scheduler, and feature extractor) are not + # automatically cached correctly if individual models are fetched one at a time. + print("Download stable diffusion PyTorch pipeline...") + pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float32) + config.vae_sample_size = pipeline.vae.config.sample_size + config.cross_attention_dim = pipeline.unet.config.cross_attention_dim + config.unet_sample_size = pipeline.unet.config.sample_size + + model_info = {} + + submodel_names = ["vae_encoder", "vae_decoder", "unet", "text_encoder"] + + has_safety_checker = getattr(pipeline, "safety_checker", None) is not None + if has_safety_checker: + submodel_names.append("safety_checker") + + for submodel_name in submodel_names: + print(f"\nOptimizing {submodel_name}") + + olive_config = None + olive_config_path = script_dir / f"config_{submodel_name}.json" + with olive_config_path.open() as fin: + olive_config = json.load(fin) + olive_config = update_config_with_provider(olive_config, provider, submodel_name) + + if submodel_name in ("unet", "text_encoder"): + olive_config["input_model"]["model_path"] = model_id + else: + # Only the unet & text encoder are affected by LoRA, so it's better to use the base model ID for + # other models: the Olive cache is based on the JSON config, and two LoRA variants with the same + # base model ID should be able to reuse previously optimized copies. + olive_config["input_model"]["model_path"] = base_model_id + + olive_run(olive_config) + + from sd_utils.ort import save_optimized_onnx_submodel + + save_optimized_onnx_submodel(submodel_name, provider, model_info) + + if provider == "vitisai": + from sd_utils.vai import save_vai_pipeline + + save_vai_pipeline( + has_safety_checker, model_info, optimized_model_dir, unoptimized_model_dir, pipeline, submodel_names + ) + else: + from sd_utils.ort import save_onnx_pipeline + + save_onnx_pipeline( + has_safety_checker, model_info, optimized_model_dir, unoptimized_model_dir, pipeline, submodel_names + ) + + return model_info + + +def parse_common_args(raw_args): + parser = argparse.ArgumentParser("Common arguments") + + parser.add_argument("--model_id", default="CompVis/stable-diffusion-v1-4", type=str) + parser.add_argument( + "--provider", + default="vitisai", + type=str, + choices=["vitisai", "cpu", "cuda"], + help="Execution provider to use", + ) + parser.add_argument("--optimize", action="store_true", help="Runs the optimization step") + parser.add_argument("--clean_cache", action="store_true", help="Deletes the Olive cache") + parser.add_argument("--test_unoptimized", action="store_true", help="Use unoptimized model for inference") + parser.add_argument("--batch_size", default=1, type=int, help="Number of images to generate per batch") + parser.add_argument( + "--prompt", + default=( + "castle surrounded by water and nature, village, volumetric lighting, photorealistic, " + "detailed and intricate, fantasy, epic cinematic shot, mountains, 8k ultra hd" + ), + type=str, + ) + parser.add_argument( + "--guidance_scale", + default=7.5, + type=float, + help="Guidance scale as defined in Classifier-Free Diffusion Guidance", + ) + parser.add_argument("--num_images", default=1, type=int, help="Number of images to generate") + parser.add_argument("--num_inference_steps", default=50, type=int, help="Number of steps in diffusion process") + parser.add_argument("--interactive", action="store_true", help="Run with a GUI") + parser.add_argument("--tempdir", default=None, type=str, help="Root directory for tempfile directories and files") + parser.add_argument( + "--strength", + default=1.0, + type=float, + help=( + "Value between 0.0 and 1.0, that controls the amount of noise that is added to the input image. " + "Values that approach 1.0 enable lots of variations but will also produce images " + "that are not semantically consistent with the input." + ), + ) + parser.add_argument("--image_size", default=512, type=int, help="Width and height of the images to generate") + parser.add_argument( + "--seed", + default=None, + type=int, + help="The seed to give to the generator to generate deterministic results.", + ) + + return parser.parse_known_args(raw_args) + + +def parse_ort_args(raw_args): + parser = argparse.ArgumentParser("ONNX Runtime arguments") + + parser.add_argument( + "--static_dims", + action="store_true", + help="DEPRECATED (now enabled by default). Use --dynamic_dims to disable static_dims.", + ) + parser.add_argument("--dynamic_dims", action="store_true", help="Disable static shape optimization") + + return parser.parse_known_args(raw_args) + + +def main(raw_args=None): + common_args, extra_args = parse_common_args(raw_args) + + provider = common_args.provider + model_id = common_args.model_id + + script_dir = Path(__file__).resolve().parent + unoptimized_model_dir = script_dir / "model" / "unoptimized" / model_id + optimized_model_dir = script_dir / "model" / f"optimized-{provider}" / model_id + + if common_args.clean_cache: + shutil.rmtree(script_dir / "cache", ignore_errors=True) + + guidance_scale = common_args.guidance_scale + + if model_id == "stabilityai/sd-turbo" and guidance_scale > 0: + guidance_scale = 0.0 + print(f"WARNING: Classifier free guidance has been forcefully disabled since {model_id} doesn't support it.") + + ort_args, extra_args = parse_ort_args(extra_args) + + if common_args.optimize or not optimized_model_dir.exists(): + set_tempdir(common_args.tempdir) + + # TODO(jstoecker): clean up warning filter (mostly during conversion from torch to ONNX) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from sd_utils.ort import validate_args + + validate_args(ort_args, provider) + optimize( + common_args, + unoptimized_model_dir, + optimized_model_dir, + ) + + generator = None if common_args.seed is None else np.random.RandomState(seed=common_args.seed) + + if not common_args.optimize: + model_dir = unoptimized_model_dir if common_args.test_unoptimized else optimized_model_dir + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if provider == "vitisai": + from sd_utils.vai import get_vai_pipeline + + pipeline = get_vai_pipeline(model_dir, common_args) + else: + from sd_utils.ort import get_ort_pipeline + + pipeline = get_ort_pipeline(model_dir, common_args, ort_args, guidance_scale) + + if common_args.interactive: + run_inference_gui( + pipeline, + common_args.prompt, + common_args.num_images, + common_args.batch_size, + common_args.image_size, + common_args.num_inference_steps, + guidance_scale, + common_args.strength, + provider=provider, + generator=generator, + ) + else: + run_inference_loop( + pipeline, + common_args.prompt, + common_args.num_images, + common_args.batch_size, + common_args.image_size, + common_args.num_inference_steps, + guidance_scale, + common_args.strength, + provider=provider, + generator=generator, + ) + + +if __name__ == "__main__": + main() diff --git a/sd-legacy-stable-diffusion-v1-5/VitisAI/user_script.py b/sd-legacy-stable-diffusion-v1-5/VitisAI/user_script.py new file mode 100644 index 00000000..fe647855 --- /dev/null +++ b/sd-legacy-stable-diffusion-v1-5/VitisAI/user_script.py @@ -0,0 +1,382 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +import random + +import numpy as np +import torch +from diffusers import AutoencoderKL, UNet2DConditionModel +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from huggingface_hub import model_info +from model_adaptations import monkey_patch_model +from olive.data.registry import Registry +from sd_utils import config +from transformers.models.clip.modeling_clip import CLIPTextModel + +# ruff: noqa: T201 + +# Generated data helpers + + +class BaseDataLoader: + def __init__(self, total): + self.data = [] + self.total = total + self.data_folders = [config.data_dir / f.name for f in os.scandir(config.data_dir) if f.is_dir()] + self.data_folders.sort() + + def __getitem__(self, idx): + if idx >= len(self.data) or idx >= self.total: + raise StopIteration + # print(f"Process data {idx}") + return self.data[idx] + + def load(self, file): + self.data.append({key: torch.from_numpy(value) for key, value in np.load(file).items()}) + + def finish_load(self): + if len(self.data) > self.total: + self.data = random.sample(self.data, self.total) + + +class UnetGeneratedDataLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + for f in self.data_folders: + i = 0 + while True: + file = f / f"{i}_unet_input.npz" + if not os.path.exists(file): + break + self.load(file) + file = f / f"{i}_unet_input_neg.npz" + if os.path.exists(file): + self.load(file) + i += 1 + self.finish_load() + + +class TextEncoderGeneratedDataLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + for f in self.data_folders: + self.load(f / "text_inputs.npz") + self.load(f / "uncond_input.npz") + self.finish_load() + + +class VaeDecoderGeneratedDataLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + for f in self.data_folders: + self.load(f / "latent.npz") + self.finish_load() + + +class VaeEncoderGeneratedDataLoader(BaseDataLoader): + def __init__(self, total): + super().__init__(total) + for f in self.data_folders: + self.load(f / "output_img.npz") + self.finish_load() + + +# Helper latency-only dataloader that creates random tensors with no label +class RandomDataLoader: + def __init__(self, create_inputs_func, batch_size, torch_dtype): + self.create_input_func = create_inputs_func + self.batch_size = batch_size + self.torch_dtype = torch_dtype + + def __getitem__(self, idx): + label = None + return self.create_input_func(self.batch_size, self.torch_dtype), label + + +def get_base_model_name(model_name): + return model_info(model_name).cardData.get("base_model") or model_name + + +def is_lora_model(model_name): + # TODO(jstoecker): might be a better way to detect (e.g. presence of LORA weights file) + return model_name != get_base_model_name(model_name) + + +# Merges LoRA weights into the layers of a base model +def merge_lora_weights(base_model, lora_model_id, submodel_name="unet", scale=1.0): + import inspect + from collections import defaultdict + from functools import reduce + + try: + from diffusers.loaders import LORA_WEIGHT_NAME + except ImportError: + # moved in version 0.24.0 + from diffusers.loaders.lora import LORA_WEIGHT_NAME + from diffusers.models.attention_processor import LoRAAttnProcessor + from diffusers.utils.hub_utils import _get_model_file + + parameters = inspect.signature(_get_model_file).parameters + + kwargs = {} + if "use_auth_token" in parameters: + kwargs["use_auth_token"] = None + elif "token" in parameters: + kwargs["token"] = None + + # Load LoRA weights + model_file = _get_model_file( + lora_model_id, + weights_name=LORA_WEIGHT_NAME, + cache_dir=None, + force_download=False, + resume_download=False, + proxies=None, + local_files_only=False, + revision=None, + subfolder=None, + user_agent={ + "file_type": "attn_procs_weights", + "framework": "pytorch", + }, + **kwargs, + ) + lora_state_dict = torch.load(model_file, map_location="cpu") + + # All keys in the LoRA state dictionary should have 'lora' somewhere in the string. + keys = list(lora_state_dict.keys()) + assert all("lora" in k for k in keys) + + if all(key.startswith(submodel_name) for key in keys): + # New format (https://github.com/huggingface/diffusers/pull/2918) supports LoRA weights in both the + # unet and text encoder where keys are prefixed with 'unet' or 'text_encoder', respectively. + submodel_state_dict = {k: v for k, v in lora_state_dict.items() if k.startswith(submodel_name)} + else: + # Old format. Keys will not have any prefix. This only applies to unet, so exit early if this is + # optimizing the text encoder. + if submodel_name != "unet": + return + submodel_state_dict = lora_state_dict + + # Group LoRA weights into attention processors + attn_processors = {} + lora_grouped_dict = defaultdict(dict) + for key, value in submodel_state_dict.items(): + attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) + lora_grouped_dict[attn_processor_key][sub_key] = value + + for key, value_dict in lora_grouped_dict.items(): + rank = value_dict["to_k_lora.down.weight"].shape[0] + cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1] + hidden_size = value_dict["to_k_lora.up.weight"].shape[0] + + attn_processors[key] = LoRAAttnProcessor( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + ) + attn_processors[key].load_state_dict(value_dict) + + # Merge LoRA attention processor weights into existing Q/K/V/Out weights + for name, proc in attn_processors.items(): + attention_name = name[: -len(".processor")] + attention = reduce(getattr, attention_name.split(sep="."), base_model) + attention.to_q.weight.data += scale * torch.mm(proc.to_q_lora.up.weight, proc.to_q_lora.down.weight) + attention.to_k.weight.data += scale * torch.mm(proc.to_k_lora.up.weight, proc.to_k_lora.down.weight) + attention.to_v.weight.data += scale * torch.mm(proc.to_v_lora.up.weight, proc.to_v_lora.down.weight) + attention.to_out[0].weight.data += scale * torch.mm(proc.to_out_lora.up.weight, proc.to_out_lora.down.weight) + + +# ----------------------------------------------------------------------------- +# TEXT ENCODER +# ----------------------------------------------------------------------------- + + +def text_encoder_inputs(batch_size, torch_dtype): + return torch.zeros((batch_size, 77), dtype=torch_dtype) + + +def text_encoder_load(model_name): + base_model_id = get_base_model_name(model_name) + model = CLIPTextModel.from_pretrained(base_model_id, subfolder="text_encoder") + if is_lora_model(model_name): + merge_lora_weights(model, model_name, "text_encoder") + return model + + +def text_encoder_conversion_inputs(model=None): + return text_encoder_inputs(1, torch.int32) + + +@Registry.register_dataloader() +def text_encoder_data_loader(dataset, batch_size, *args, **kwargs): + return RandomDataLoader(text_encoder_inputs, batch_size, torch.int32) + + +@Registry.register_dataloader() +def text_encoder_quantize_data_loader(dataset, data_num, *args, **kwargs): + return TextEncoderGeneratedDataLoader(data_num) + + +# ----------------------------------------------------------------------------- +# UNET +# ----------------------------------------------------------------------------- + + +def unet_inputs(batch_size, torch_dtype, is_conversion_inputs=False): + # TODO(jstoecker): Rename onnx::Concat_4 to text_embeds and onnx::Shape_5 to time_ids + inputs = { + "sample": torch.rand((batch_size, 4, config.unet_sample_size, config.unet_sample_size), dtype=torch_dtype), + "timestep": torch.rand((batch_size,), dtype=torch_dtype), + "encoder_hidden_states": torch.rand((batch_size, 77, config.cross_attention_dim), dtype=torch_dtype), + } + + # use as kwargs since they won't be in the correct position if passed along with the tuple of inputs + kwargs = { + "return_dict": False, + } + if is_conversion_inputs: + inputs["additional_inputs"] = { + **kwargs, + "added_cond_kwargs": { + "text_embeds": torch.rand((1, 1280), dtype=torch_dtype), + "time_ids": torch.rand((1, 5), dtype=torch_dtype), + }, + } + else: + inputs.update(kwargs) + inputs["onnx::Concat_4"] = torch.rand((1, 1280), dtype=torch_dtype) + inputs["onnx::Shape_5"] = torch.rand((1, 5), dtype=torch_dtype) + + return inputs + + +def get_unet_ov_example_input(): + encoder_hidden_state = torch.ones((2, 77, 768)) + latents_shape = (2, 4, 512 // 8, 512 // 8) + latents = torch.randn(latents_shape) + t = torch.from_numpy(np.array(1, dtype=float)) + return (latents, t, encoder_hidden_state) + + +def unet_load(model_name): + base_model_id = get_base_model_name(model_name) + model = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet") + if is_lora_model(model_name): + merge_lora_weights(model, model_name, "unet") + return model + + +def unet_load_qnn(model_name): + base_model_id = get_base_model_name(model_name) + model = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet") + if is_lora_model(model_name): + merge_lora_weights(model, model_name, "unet") + monkey_patch_model(model) + return model + + +def unet_conversion_inputs(model=None): + return tuple(unet_inputs(1, torch.float32, True).values()) + + +@Registry.register_dataloader() +def unet_data_loader(dataset, batch_size, *args, **kwargs): + return RandomDataLoader(unet_inputs, batch_size, torch.float16) + + +@Registry.register_dataloader() +def unet_quantize_data_loader(dataset, data_num, *args, **kwargs): + return UnetGeneratedDataLoader(data_num) + + +# ----------------------------------------------------------------------------- +# VAE ENCODER +# ----------------------------------------------------------------------------- + + +def vae_encoder_inputs(batch_size, torch_dtype): + return {"sample": torch.rand((batch_size, 3, config.vae_sample_size, config.vae_sample_size), dtype=torch_dtype)} + + +def vae_encoder_load(model_name): + base_model_id = get_base_model_name(model_name) + model = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae") + model.forward = lambda sample: model.encode(sample)[0].sample() + return model + + +def vae_encoder_conversion_inputs(model=None): + return tuple(vae_encoder_inputs(1, torch.float32).values()) + + +@Registry.register_dataloader() +def vae_encoder_data_loader(dataset, batch_size, *args, **kwargs): + return RandomDataLoader(vae_encoder_inputs, batch_size, torch.float16) + + +@Registry.register_dataloader() +def vae_encoder_quantize_data_loader(dataset, data_num, *args, **kwargs): + return VaeEncoderGeneratedDataLoader(data_num) + + +# ----------------------------------------------------------------------------- +# VAE DECODER +# ----------------------------------------------------------------------------- + + +def vae_decoder_inputs(batch_size, torch_dtype): + return { + "latent_sample": torch.rand( + (batch_size, 4, config.unet_sample_size, config.unet_sample_size), dtype=torch_dtype + ) + } + + +def vae_decoder_load(model_name): + base_model_id = get_base_model_name(model_name) + model = AutoencoderKL.from_pretrained(base_model_id, subfolder="vae") + model.forward = model.decode + return model + + +def vae_decoder_conversion_inputs(model=None): + return tuple(vae_decoder_inputs(1, torch.float32).values()) + + +@Registry.register_dataloader() +def vae_decoder_data_loader(dataset, batch_size, *args, **kwargs): + return RandomDataLoader(vae_decoder_inputs, batch_size, torch.float16) + + +@Registry.register_dataloader() +def vae_decoder_quantize_data_loader(dataset, data_num, *args, **kwargs): + return VaeDecoderGeneratedDataLoader(data_num) + + +# ----------------------------------------------------------------------------- +# SAFETY CHECKER +# ----------------------------------------------------------------------------- + + +def safety_checker_inputs(batch_size, torch_dtype): + return { + "clip_input": torch.rand((batch_size, 3, 224, 224), dtype=torch_dtype), + "images": torch.rand((batch_size, config.vae_sample_size, config.vae_sample_size, 3), dtype=torch_dtype), + } + + +def safety_checker_load(model_name): + base_model_id = get_base_model_name(model_name) + model = StableDiffusionSafetyChecker.from_pretrained(base_model_id, subfolder="safety_checker") + model.forward = model.forward_onnx + return model + + +def safety_checker_conversion_inputs(model=None): + return tuple(safety_checker_inputs(1, torch.float32).values()) + + +@Registry.register_dataloader() +def safety_checker_data_loader(dataset, batch_size, *args, **kwargs): + return RandomDataLoader(safety_checker_inputs, batch_size, torch.float16)