diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..4fc94f3541da4fc063bd9b4574f23fbb8afebb31 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/sculpture.png filter=lfs diff=lfs merge=lfs -text +assets/teaser_figure.png filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md new file mode 100644 index 0000000000000000000000000000000000000000..49b240877d130b0ac8144fd91b08054efb6e78b6 --- /dev/null +++ b/ORIGINAL_README.md @@ -0,0 +1,191 @@ +
+

InstantIR: Blind Image Restoration with
Instant Generative Reference

+ +[**Jen-Yuan Huang**](https://jy-joy.github.io)1 2, [**Haofan Wang**](https://haofanwang.github.io/)2, [**Qixun Wang**](https://github.com/wangqixun)2, [**Xu Bai**](https://huggingface.co./baymin0220)2, Hao Ai2, Peng Xing2, [**Jen-Tse Huang**](https://penguinnnnn.github.io)3
+ +1Peking University ยท 2InstantX Team ยท 3The Chinese University of Hong Kong + + + + + + + + + + +
+ +**InstantIR** is a novel single-image restoration model designed to resurrect your damaged images, delivering extrem-quality yet realistic details. You can further boost **InstantIR** performance with additional text prompts, even achieve customized editing! + + + + + + +## ๐Ÿ“ข News +- **11/03/2024** ๐Ÿ”ฅ We provide a Gradio launching script for InstantIR, you can now deploy it on your local machine! +- **11/02/2024** ๐Ÿ”ฅ InstantIR is now compatitble with ๐Ÿงจ `diffusers`, you can utilize features from this fascinating package! +- **10/15/2024** ๐Ÿ”ฅ Code and model released! + +## ๐Ÿ“ TODOs: +- [ ] Launch online demo +- [x] Remove dependency on local `diffusers` +- [x] Gradio launching script + +## โœจ Usage + + +### Quick start +#### 1. Clone this repo and setting up environment +```sh +git clone https://github.com/JY-Joy/InstantIR.git +cd InstantIR +conda create -n instantir python=3.9 -y +conda activate instantir +pip install -r requirements.txt +``` + +#### 2. Download pre-trained models + +InstantIR is built on SDXL and DINOv2. You can download them either directly from ๐Ÿค— huggingface or using Python package. + +| ๐Ÿค— link | Python command +| :--- | :---------- +|[SDXL](https://huggingface.co./stabilityai/stable-diffusion-xl-base-1.0) | `hf_hub_download(repo_id="stabilityai/stable-diffusion-xl-base-1.0")` +|[facebook/dinov2-large](https://huggingface.co./facebook/dinov2-large) | `hf_hub_download(repo_id="facebook/dinov2-large")` +|[InstantX/InstantIR](https://huggingface.co./InstantX/InstantIR) | `hf_hub_download(repo_id="InstantX/InstantIR")` + +Note: Make sure to import the package first with `from huggingface_hub import hf_hub_download` if you are using Python script. + +#### 3. Inference + +You can run InstantIR inference using `infer.sh` with the following arguments specified. + +```sh +infer.sh \ + --sdxl_path \ + --vision_encoder_path \ + --instantir_path \ + --test_path \ + --out_path +``` + +See `infer.py` for more config options. + +#### 4. Using tips + +InstantIR is powerful, but with your help it can do better. InstantIR's flexible pipeline makes it tunable to a large extent. Here are some tips we found particularly useful for various cases you may encounter: +- **Over-smoothing**: reduce `--cfg` to 3.0๏ฝž5.0. Higher CFG scales can sometimes rigid lines or lack of details. +- **Low fidelity**: set `--preview_start` to 0.1~0.4 to preserve fidelity from inputs. The previewer can yield misleading references when input latent is too noisy. In such cases, we suggest to disable the previewer at early timesteps. +- **Local distortions**: set `--creative_start` to 0.6~0.8. This will let InstantIR render freely in the late diffusion process, where the high-frequency details are generated. Smaller `--creative_start` spares more spaces for creative restoration, but will diminish fidelity. +- **Faster inference**: higher `--preview_start` and lower `--creative_start` can both reduce computational costs and accelerate InstantIR inference. + +> [!CAUTION] +> These features are training-free and thus experimental. If you would like to try, we suggest to tune these parameters case-by-case. + +### Use InstantIR with diffusers ๐Ÿงจ + +InstantIR is fully compatible with `diffusers` and is supported by all those powerful features in this package. You can directly load InstantIR via `diffusers` snippet: + +```py +# !pip install diffusers opencv-python transformers accelerate +import torch +from PIL import Image + +from diffusers import DDPMScheduler +from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler + +from module.ip_adapter.utils import load_adapter_to_pipe +from pipelines.sdxl_instantir import InstantIRPipeline + +# suppose you have InstantIR weights under ./models +instantir_path = f'./models' + +# load pretrained models +pipe = InstantIRPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16) + +# load adapter +load_adapter_to_pipe( + pipe, + f"{instantir_path}/adapter.pt", + image_encoder_or_path = 'facebook/dinov2-large', +) + +# load previewer lora +pipe.prepare_previewers(instantir_path) +pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler") +lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config) + +# load aggregator weights +pretrained_state_dict = torch.load(f"{instantir_path}/aggregator.pt") +pipe.aggregator.load_state_dict(pretrained_state_dict) + +# send to GPU and fp16 +pipe.to(device='cuda', dtype=torch.float16) +pipe.aggregator.to(device='cuda', dtype=torch.float16) +``` + +Then, you just need to call the `pipe` and InstantIR will handle your image! + +```py +# load a broken image +low_quality_image = Image.open('./assets/sculpture.png').convert("RGB") + +# InstantIR restoration +image = pipe( + image=low_quality_image, + previewer_scheduler=lcm_scheduler, +).images[0] +``` + +### Deploy local gradio demo + +We provide a python script to launch a local gradio demo of InstantIR, with basic and some advanced features implemented. Start by running the following command in your terminal: + +```sh +INSTANTIR_PATH= python gradio_demo/app.py +``` + +Then, visit your local demo via your browser at `http://localhost:7860`. + + +## โš™๏ธ Training + +### Prepare data + +InstantIR is trained on [DIV2K](https://www.kaggle.com/datasets/joe1995/div2k-dataset), [Flickr2K](https://www.kaggle.com/datasets/daehoyang/flickr2k), [LSDIR](https://data.vision.ee.ethz.ch/yawli/index.html) and [FFHQ](https://www.kaggle.com/datasets/rahulbhalley/ffhq-1024x1024). We adopt dataset weighting to balance the distribution. You can config their weights in ```config_files/IR_dataset.yaml```. Download these training sets and put them under a same directory, which will be used in the following training configurations. + +### Two-stage training +As described in our paper, the training of InstantIR is conducted in two stages. We provide corresponding `.sh` training scripts for each stage. Make sure you have the following arguments adapted to your own use case: + +| Argument | Value +| :--- | :---------- +| `--pretrained_model_name_or_path` | path to your SDXL folder +| `--feature_extractor_path` | path to your DINOv2 folder +| `--train_data_dir` | your training data directory +| `--output_dir` | path to save model weights +| `--logging_dir` | path to save logs +| `` | number of available GPUs + +Other training hyperparameters we used in our experiments are provided in the corresponding `.sh` scripts. You can tune them according to your own needs. + +## ๐Ÿ‘ Acknowledgment +Our work is sponsored by [HuggingFace](https://huggingface.co.) and [fal.ai](https://fal.ai). + +## ๐ŸŽ“ Citation + +If InstantIR is helpful to your work, please cite our paper via: + +``` +@article{huang2024instantir, + title={InstantIR: Blind Image Restoration with Instant Generative Reference}, + author={Huang, Jen-Yuan and Wang, Haofan and Wang, Qixun and Bai, Xu and Ai, Hao and Xing, Peng and Huang, Jen-Tse}, + journal={arXiv preprint arXiv:2410.06551}, + year={2024} +} +``` diff --git a/assets/.DS_Store b/assets/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/assets/.DS_Store differ diff --git a/assets/Konan.png b/assets/Konan.png new file mode 100644 index 0000000000000000000000000000000000000000..24b44e20ff41032c4037790f88b7ca20f1e0adb6 Binary files /dev/null and b/assets/Konan.png differ diff --git a/assets/Naruto.png b/assets/Naruto.png new file mode 100644 index 0000000000000000000000000000000000000000..7eb3836d04f19864147b7d761c68acd66ca4b0a2 Binary files /dev/null and b/assets/Naruto.png differ diff --git a/assets/cottage.png b/assets/cottage.png new file mode 100644 index 0000000000000000000000000000000000000000..dcefa84eed0fac611b273b1efa58af5bd52f1e00 Binary files /dev/null and b/assets/cottage.png differ diff --git a/assets/dog.png b/assets/dog.png new file mode 100644 index 0000000000000000000000000000000000000000..58125cfdf3c8ee64c1e3930d7505f2c74160df80 Binary files /dev/null and b/assets/dog.png differ diff --git a/assets/lady.png b/assets/lady.png new file mode 100644 index 0000000000000000000000000000000000000000..a3df98b191dcaf2793d9364f1f525acbce165d7b Binary files /dev/null and b/assets/lady.png differ diff --git a/assets/man.png b/assets/man.png new file mode 100644 index 0000000000000000000000000000000000000000..4a738b18f659ac7a23f03c3e3388f3f450382b2e Binary files /dev/null and b/assets/man.png differ diff --git a/assets/panda.png b/assets/panda.png new file mode 100644 index 0000000000000000000000000000000000000000..f7c47480bc007be8a130dc4f24e13201373f868f Binary files /dev/null and b/assets/panda.png differ diff --git a/assets/sculpture.png b/assets/sculpture.png new file mode 100644 index 0000000000000000000000000000000000000000..aa92a295a0884976629cd6796ddbe3d9d431827c --- /dev/null +++ b/assets/sculpture.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c4af7c3dc545d2f48b0ac2afef69bd7b1f0489ced7ea452d92f69ff5a9d4019 +size 1200996 diff --git a/assets/teaser_figure.png b/assets/teaser_figure.png new file mode 100644 index 0000000000000000000000000000000000000000..cbe94f8394b912afd3e639ff49646e479583af8d --- /dev/null +++ b/assets/teaser_figure.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9c7e8e59af17516d11e21c5bc56b48824a3875c81e4afb181a5c3facc217d08 +size 16937178 diff --git a/config_files/IR_dataset.yaml b/config_files/IR_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..acfaa864ca8674082d74d268ffe266095608b6d6 --- /dev/null +++ b/config_files/IR_dataset.yaml @@ -0,0 +1,9 @@ +datasets: + - dataset_folder: 'ffhq' + dataset_weight: 0.1 + - dataset_folder: 'DIV2K' + dataset_weight: 0.3 + - dataset_folder: 'LSDIR' + dataset_weight: 0.3 + - dataset_folder: 'Flickr2K' + dataset_weight: 0.1 diff --git a/config_files/losses.yaml b/config_files/losses.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8a650909b1bae32cbbb7b82817a04e1750850d7e --- /dev/null +++ b/config_files/losses.yaml @@ -0,0 +1,19 @@ +diffusion_losses: +- name: L2Loss + weight: 1 +lcm_losses: +- name: HuberLoss + weight: 1 +# - name: DINOLoss +# weight: 1e-3 +# - name: L2Loss +# weight: 5e-2 +# - name: LPIPSLoss +# weight: 1e-3 +# - name: DreamSIMLoss +# weight: 1e-3 +# - name: IDLoss +# weight: 1e-3 +# visualize_every_k: 50 +# init_params: +# pretrained_arcface_path: /home/dcor/orlichter/consistency_encoder_private/pretrained_models/model_ir_se50.pth \ No newline at end of file diff --git a/config_files/val_dataset.yaml b/config_files/val_dataset.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9674bb2f16ae220d6749f7294e1c05a5f0b453e6 --- /dev/null +++ b/config_files/val_dataset.yaml @@ -0,0 +1,7 @@ +datasets: + - dataset_folder: 'ffhq' + dataset_weight: 0.1 + - dataset_folder: 'DIV2K' + dataset_weight: 0.45 + - dataset_folder: 'LSDIR' + dataset_weight: 0.45 diff --git a/data/data_config.py b/data/data_config.py new file mode 100644 index 0000000000000000000000000000000000000000..2536debf191ecbbe05eeebb1778141ed73456b47 --- /dev/null +++ b/data/data_config.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass, field +from typing import Optional, List + + +@dataclass +class SingleDataConfig: + dataset_folder: str + imagefolder: bool = True + dataset_weight: float = 1.0 # Not used yet + +@dataclass +class DataConfig: + datasets: List[SingleDataConfig] + val_dataset: Optional[SingleDataConfig] = None diff --git a/data/dataset.py b/data/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..17d84cdb3571b0f9b5de9cf6ee9a9b36d5723e6f --- /dev/null +++ b/data/dataset.py @@ -0,0 +1,202 @@ +from pathlib import Path +from typing import Optional + +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +import json +import random +from facenet_pytorch import MTCNN +import torch + +from utils.utils import extract_faces_and_landmarks, REFERNCE_FACIAL_POINTS_RELATIVE + +def load_image(image_path: str) -> Image: + image = Image.open(image_path) + image = exif_transpose(image) + if not image.mode == "RGB": + image = image.convert("RGB") + return image + + +class ImageDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + metadata_path: Optional[str] = None, + prompt_in_filename=False, + use_only_vanilla_for_encoder=False, + concept_placeholder='a face', + size=1024, + center_crop=False, + aug_images=False, + use_only_decoder_prompts=False, + crop_head_for_encoder_image=False, + random_target_prob=0.0, + ): + self.mtcnn = MTCNN(device='cuda:0') + self.mtcnn.forward = self.mtcnn.detect + resize_factor = 1.3 + self.resized_reference_points = REFERNCE_FACIAL_POINTS_RELATIVE / resize_factor + (resize_factor - 1) / (2 * resize_factor) + self.size = size + self.center_crop = center_crop + self.concept_placeholder = concept_placeholder + self.prompt_in_filename = prompt_in_filename + self.aug_images = aug_images + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.name_to_label = None + self.crop_head_for_encoder_image = crop_head_for_encoder_image + self.random_target_prob = random_target_prob + + self.use_only_decoder_prompts = use_only_decoder_prompts + + self.instance_data_root = Path(instance_data_root) + + if not self.instance_data_root.exists(): + raise ValueError(f"Instance images root {self.instance_data_root} doesn't exist.") + + if metadata_path is not None: + with open(metadata_path, 'r') as f: + self.name_to_label = json.load(f) # dict of filename: label + # Create a reversed mapping + self.label_to_names = {} + for name, label in self.name_to_label.items(): + if use_only_vanilla_for_encoder and 'vanilla' not in name: + continue + if label not in self.label_to_names: + self.label_to_names[label] = [] + self.label_to_names[label].append(name) + self.all_paths = [self.instance_data_root / filename for filename in self.name_to_label.keys()] + + # Verify all paths exist + n_all_paths = len(self.all_paths) + self.all_paths = [path for path in self.all_paths if path.exists()] + print(f'Found {len(self.all_paths)} out of {n_all_paths} paths.') + else: + self.all_paths = [path for path in list(Path(instance_data_root).glob('**/*')) if + path.suffix.lower() in [".png", ".jpg", ".jpeg"]] + # Sort by name so that order for validation remains the same across runs + self.all_paths = sorted(self.all_paths, key=lambda x: x.stem) + + self.custom_instance_prompts = None + + self._length = len(self.all_paths) + + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + if self.prompt_in_filename: + self.prompts_set = set([self._path_to_prompt(path) for path in self.all_paths]) + else: + self.prompts_set = set([self.instance_prompt]) + + if self.aug_images: + self.aug_transforms = transforms.Compose( + [ + transforms.RandomResizedCrop(size, scale=(0.8, 1.0), ratio=(1.0, 1.0)), + transforms.RandomHorizontalFlip(p=0.5) + ] + ) + + def __len__(self): + return self._length + + def _path_to_prompt(self, path): + # Remove the extension and seed + split_path = path.stem.split('_') + while split_path[-1].isnumeric(): + split_path = split_path[:-1] + + prompt = ' '.join(split_path) + # Replace placeholder in prompt with training placeholder + prompt = prompt.replace('conceptname', self.concept_placeholder) + return prompt + + def __getitem__(self, index): + example = {} + instance_path = self.all_paths[index] + instance_image = load_image(instance_path) + example["instance_images"] = self.image_transforms(instance_image) + if self.prompt_in_filename: + example["instance_prompt"] = self._path_to_prompt(instance_path) + else: + example["instance_prompt"] = self.instance_prompt + + if self.name_to_label is None: + # If no labels, simply take the same image but with different augmentation + example["encoder_images"] = self.aug_transforms(example["instance_images"]) if self.aug_images else example["instance_images"] + example["encoder_prompt"] = example["instance_prompt"] + else: + # Randomly select another image with the same label + instance_name = str(instance_path.relative_to(self.instance_data_root)) + instance_label = self.name_to_label[instance_name] + label_set = set(self.label_to_names[instance_label]) + if len(label_set) == 1: + # We are not supposed to have only one image per label, but just in case + encoder_image_name = instance_name + print(f'WARNING: Only one image for label {instance_label}.') + else: + encoder_image_name = random.choice(list(label_set - {instance_name})) + encoder_image = load_image(self.instance_data_root / encoder_image_name) + example["encoder_images"] = self.image_transforms(encoder_image) + + if self.prompt_in_filename: + example["encoder_prompt"] = self._path_to_prompt(self.instance_data_root / encoder_image_name) + else: + example["encoder_prompt"] = self.instance_prompt + + if self.crop_head_for_encoder_image: + example["encoder_images"] = extract_faces_and_landmarks(example["encoder_images"][None], self.size, self.mtcnn, self.resized_reference_points)[0][0] + example["encoder_prompt"] = example["encoder_prompt"].format(placeholder="") + example["instance_prompt"] = example["instance_prompt"].format(placeholder="") + + if random.random() < self.random_target_prob: + random_path = random.choice(self.all_paths) + + random_image = load_image(random_path) + example["instance_images"] = self.image_transforms(random_image) + if self.prompt_in_filename: + example["instance_prompt"] = self._path_to_prompt(random_path) + + + if self.use_only_decoder_prompts: + example["encoder_prompt"] = example["instance_prompt"] + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + encoder_pixel_values = [example["encoder_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + encoder_prompts = [example["encoder_prompt"] for example in examples] + + if with_prior_preservation: + raise NotImplementedError("Prior preservation not implemented.") + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + encoder_pixel_values = torch.stack(encoder_pixel_values) + encoder_pixel_values = encoder_pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "encoder_pixel_values": encoder_pixel_values, + "prompts": prompts, "encoder_prompts": encoder_prompts} + return batch diff --git a/docs/.DS_Store b/docs/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..ecb218a51788f964c73e12aac69933b3b8193ec9 Binary files /dev/null and b/docs/.DS_Store differ diff --git a/docs/static/.DS_Store b/docs/static/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..f99ae43bb14d60ad97b7e197cf9798c9be86ac69 Binary files /dev/null and b/docs/static/.DS_Store differ diff --git a/environment.yaml b/environment.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1ce87f39711264bd96d97c60530d8ed707bb80ea --- /dev/null +++ b/environment.yaml @@ -0,0 +1,37 @@ +name: instantir +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - numpy + - pandas + - pillow + - pip + - python=3.9.15 + - pytorch=2.2.2 + - pytorch-lightning=1.6.5 + - pytorch-cuda=12.1 + - setuptools + - torchaudio=2.2.2 + - torchmetrics + - torchvision=0.17.2 + - tqdm + - pip: + - accelerate==0.25.0 + - diffusers==0.24.0 + - einops + - open-clip-torch + - opencv-python==4.8.1.78 + - tokenizers + - transformers==4.36.2 + - kornia + - facenet_pytorch + - lpips + - dreamsim + - pyrallis + - wandb + - insightface + - onnxruntime==1.17.0 + - -e git+https://github.com/openai/CLIP.git@main#egg=clip \ No newline at end of file diff --git a/gradio_demo/app.py b/gradio_demo/app.py new file mode 100644 index 0000000000000000000000000000000000000000..d5301676376eb1e1c2fbb8545651ca5fc38d3947 --- /dev/null +++ b/gradio_demo/app.py @@ -0,0 +1,250 @@ +import os +import sys +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +import torch +import numpy as np +import gradio as gr +from PIL import Image + +from diffusers import DDPMScheduler +from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler + +from module.ip_adapter.utils import load_adapter_to_pipe +from pipelines.sdxl_instantir import InstantIRPipeline + +def resize_img(input_image, max_side=1280, min_side=1024, size=None, + pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + if size is not None: + w_resize_new, h_resize_new = size + else: + # ratio = min_side / min(h, w) + # w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image + +instantir_path = os.environ['INSTANTIR_PATH'] + +device = "cuda" if torch.cuda.is_available() else "cpu" +sdxl_repo_id = "stabilityai/stable-diffusion-xl-base-1.0" +dinov2_repo_id = "facebook/dinov2-large" +lcm_repo_id = "latent-consistency/lcm-lora-sdxl" + +if torch.cuda.is_available(): + torch_dtype = torch.float16 +else: + torch_dtype = torch.float32 + +# Load pretrained models. +print("Initializing pipeline...") +pipe = InstantIRPipeline.from_pretrained( + sdxl_repo_id, + torch_dtype=torch_dtype, +) + +# Image prompt projector. +print("Loading LQ-Adapter...") +load_adapter_to_pipe( + pipe, + f"{instantir_path}/adapter.pt", + dinov2_repo_id, +) + +# Prepare previewer +lora_alpha = pipe.prepare_previewers(instantir_path) +print(f"use lora alpha {lora_alpha}") +lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True) +print(f"use lora alpha {lora_alpha}") +pipe.to(device=device, dtype=torch_dtype) +pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler") +lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config) + +# Load weights. +print("Loading checkpoint...") +aggregator_state_dict = torch.load( + f"{instantir_path}/aggregator.pt", + map_location="cpu" +) +pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True) +pipe.aggregator.to(device=device, dtype=torch_dtype) + +MAX_SEED = np.iinfo(np.int32).max +MAX_IMAGE_SIZE = 1024 + +PROMPT = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \ +ultra HD, extreme meticulous detailing, skin pore detailing, \ +hyper sharpness, perfect without deformations, \ +taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. " + +NEG_PROMPT = "blurry, out of focus, unclear, depth of field, over-smooth, \ +sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \ +dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \ +watermark, signature, jpeg artifacts, deformed, lowres" + +def unpack_pipe_out(preview_row, index): + return preview_row[index][0] + +def dynamic_preview_slider(sampling_steps): + print(sampling_steps) + return gr.Slider(label="Restoration Previews", value=sampling_steps-1, minimum=0, maximum=sampling_steps-1, step=1) + +def dynamic_guidance_slider(sampling_steps): + return gr.Slider(label="Start Free Rendering", value=sampling_steps, minimum=0, maximum=sampling_steps, step=1) + +def show_final_preview(preview_row): + return preview_row[-1][0] + +# @spaces.GPU #[uncomment to use ZeroGPU] +@torch.no_grad() +def instantir_restore( + lq, prompt="", steps=30, cfg_scale=7.0, guidance_end=1.0, + creative_restoration=False, seed=3407, height=1024, width=1024, preview_start=0.0): + if creative_restoration: + if "lcm" not in pipe.unet.active_adapters(): + pipe.unet.set_adapter('lcm') + else: + if "previewer" not in pipe.unet.active_adapters(): + pipe.unet.set_adapter('previewer') + + if isinstance(guidance_end, int): + guidance_end = guidance_end / steps + elif guidance_end > 1.0: + guidance_end = guidance_end / steps + if isinstance(preview_start, int): + preview_start = preview_start / steps + elif preview_start > 1.0: + preview_start = preview_start / steps + lq = [resize_img(lq.convert("RGB"), size=(width, height))] + generator = torch.Generator(device=device).manual_seed(seed) + timesteps = [ + i * (1000//steps) + pipe.scheduler.config.steps_offset for i in range(0, steps) + ] + timesteps = timesteps[::-1] + + prompt = PROMPT if len(prompt)==0 else prompt + neg_prompt = NEG_PROMPT + + out = pipe( + prompt=[prompt]*len(lq), + image=lq, + num_inference_steps=steps, + generator=generator, + timesteps=timesteps, + negative_prompt=[neg_prompt]*len(lq), + guidance_scale=cfg_scale, + control_guidance_end=guidance_end, + preview_start=preview_start, + previewer_scheduler=lcm_scheduler, + return_dict=False, + save_preview_row=True, + ) + for i, preview_img in enumerate(out[1]): + preview_img.append(f"preview_{i}") + return out[0][0], out[1] + +examples = [ + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "An astronaut riding a green horse", + "A delicious ceviche cheesecake slice", +] + +css=""" +#col-container { + margin: 0 auto; + max-width: 640px; +} +""" + +with gr.Blocks() as demo: + gr.Markdown( + """ + # InstantIR: Blind Image Restoration with Instant Generative Reference. + + ### **Official ๐Ÿค— Gradio demo of [InstantIR](https://arxiv.org/abs/2410.06551).** + ### **InstantIR can not only help you restore your broken image, but also capable of imaginative re-creation following your text prompts. See advance usage for more details!** + ## Basic usage: revitalize your image + 1. Upload an image you want to restore; + 2. Optionally, tune the `Steps` `CFG Scale` parameters. Typically higher steps lead to better results, but less than 50 is recommended for efficiency; + 3. Click `InstantIR magic!`. + """) + with gr.Row(): + lq_img = gr.Image(label="Low-quality image", type="pil") + with gr.Column(): + with gr.Row(): + steps = gr.Number(label="Steps", value=30, step=1) + cfg_scale = gr.Number(label="CFG Scale", value=7.0, step=0.1) + with gr.Row(): + height = gr.Number(label="Height", value=1024, step=1) + weight = gr.Number(label="Weight", value=1024, step=1) + seed = gr.Number(label="Seed", value=42, step=1) + # guidance_start = gr.Slider(label="Guidance Start", value=1.0, minimum=0.0, maximum=1.0, step=0.05) + guidance_end = gr.Slider(label="Start Free Rendering", value=30, minimum=0, maximum=30, step=1) + preview_start = gr.Slider(label="Preview Start", value=0, minimum=0, maximum=30, step=1) + prompt = gr.Textbox(label="Restoration prompts (Optional)", placeholder="") + mode = gr.Checkbox(label="Creative Restoration", value=False) + with gr.Row(): + with gr.Row(): + restore_btn = gr.Button("InstantIR magic!") + clear_btn = gr.ClearButton() + index = gr.Slider(label="Restoration Previews", value=29, minimum=0, maximum=29, step=1) + with gr.Row(): + output = gr.Image(label="InstantIR restored", type="pil") + preview = gr.Image(label="Preview", type="pil") + pipe_out = gr.Gallery(visible=False) + clear_btn.add([lq_img, output, preview]) + restore_btn.click( + instantir_restore, inputs=[ + lq_img, prompt, steps, cfg_scale, guidance_end, + mode, seed, height, weight, preview_start, + ], + outputs=[output, pipe_out], api_name="InstantIR" + ) + steps.change(dynamic_guidance_slider, inputs=steps, outputs=guidance_end) + output.change(dynamic_preview_slider, inputs=steps, outputs=index) + index.release(unpack_pipe_out, inputs=[pipe_out, index], outputs=preview) + output.change(show_final_preview, inputs=pipe_out, outputs=preview) + gr.Markdown( + """ + ## Advance usage: + ### Browse restoration variants: + 1. After InstantIR processing, drag the `Restoration Previews` slider to explore other in-progress versions; + 2. If you like one of them, set the `Start Free Rendering` slider to the same value to get a more refined result. + ### Creative restoration: + 1. Check the `Creative Restoration` checkbox; + 2. Input your text prompts in the `Restoration prompts` textbox; + 3. Set `Start Free Rendering` slider to a medium value (around half of the `steps`) to provide adequate room for InstantIR creation. + + ## Examples + Here are some examplar usage of InstantIR: + """) + # examples = gr.Gallery(label="Examples") + + gr.Markdown( + """ + ## Citation + If InstantIR is helpful to your work, please cite our paper via: + + ``` + @article{huang2024instantir, + title={InstantIR: Blind Image Restoration with Instant Generative Reference}, + author={Huang, Jen-Yuan and Wang, Haofan and Wang, Qixun and Bai, Xu and Ai, Hao and Xing, Peng and Huang, Jen-Tse}, + journal={arXiv preprint arXiv:2410.06551}, + year={2024} + } + ``` + """) + +demo.queue().launch() \ No newline at end of file diff --git a/infer.py b/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..aa021623a6a6a0e90acc271483f7413dfe24bd60 --- /dev/null +++ b/infer.py @@ -0,0 +1,381 @@ +import os +import argparse +import numpy as np +import torch + +from PIL import Image +from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler + +from diffusers import DDPMScheduler + +from module.ip_adapter.utils import load_adapter_to_pipe +from pipelines.sdxl_instantir import InstantIRPipeline + + +def name_unet_submodules(unet): + def recursive_find_module(name, module, end=False): + if end: + for sub_name, sub_module in module.named_children(): + sub_module.full_name = f"{name}.{sub_name}" + return + if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return + elif "resnets" in name: return + for sub_name, sub_module in module.named_children(): + end = True if sub_name == "transformer_blocks" else False + recursive_find_module(f"{name}.{sub_name}", sub_module, end) + + for name, module in unet.named_children(): + recursive_find_module(name, module) + + +def resize_img(input_image, max_side=1280, min_side=1024, size=None, + pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64): + + w, h = input_image.size + if size is not None: + w_resize_new, h_resize_new = size + else: + # ratio = min_side / min(h, w) + # w, h = round(ratio*w), round(ratio*h) + ratio = max_side / max(h, w) + input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode) + w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number + h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number + input_image = input_image.resize([w_resize_new, h_resize_new], mode) + + if pad_to_max_side: + res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255 + offset_x = (max_side - w_resize_new) // 2 + offset_y = (max_side - h_resize_new) // 2 + res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image) + input_image = Image.fromarray(res) + return input_image + + +def tensor_to_pil(images): + """ + Convert image tensor or a batch of image tensors to PIL image(s). + """ + images = images.clamp(0, 1) + images_np = images.detach().cpu().numpy() + if images_np.ndim == 4: + images_np = np.transpose(images_np, (0, 2, 3, 1)) + elif images_np.ndim == 3: + images_np = np.transpose(images_np, (1, 2, 0)) + images_np = images_np[None, ...] + images_np = (images_np * 255).round().astype("uint8") + if images_np.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images_np] + else: + pil_images = [Image.fromarray(image[:, :, :3]) for image in images_np] + + return pil_images + + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +def main(args, device): + + # Load pretrained models. + pipe = InstantIRPipeline.from_pretrained( + args.sdxl_path, + torch_dtype=torch.float16, + ) + + # Image prompt projector. + print("Loading LQ-Adapter...") + load_adapter_to_pipe( + pipe, + args.adapter_model_path if args.adapter_model_path is not None else os.path.join(args.instantir_path, 'adapter.pt'), + args.vision_encoder_path, + use_clip_encoder=args.use_clip_encoder, + ) + + # Prepare previewer + previewer_lora_path = args.previewer_lora_path if args.previewer_lora_path is not None else args.instantir_path + if previewer_lora_path is not None: + lora_alpha = pipe.prepare_previewers(previewer_lora_path) + print(f"use lora alpha {lora_alpha}") + pipe.to(device=device, dtype=torch.float16) + pipe.scheduler = DDPMScheduler.from_pretrained(args.sdxl_path, subfolder="scheduler") + lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config) + + # Load weights. + print("Loading checkpoint...") + pretrained_state_dict = torch.load(os.path.join(args.instantir_path, "aggregator.pt"), map_location="cpu") + pipe.aggregator.load_state_dict(pretrained_state_dict) + pipe.aggregator.to(device, dtype=torch.float16) + + #################### Restoration #################### + + post_fix = f"_{args.post_fix}" if args.post_fix else "" + os.makedirs(f"{args.out_path}/{post_fix}", exist_ok=True) + + processed_imgs = os.listdir(os.path.join(args.out_path, post_fix)) + lq_files = [] + lq_batch = [] + if os.path.isfile(args.test_path): + all_inputs = [args.test_path.split("/")[-1]] + else: + all_inputs = os.listdir(args.test_path) + all_inputs.sort() + for file in all_inputs: + if file in processed_imgs: + print(f"Skip {file}") + continue + lq_batch.append(f"{file}") + if len(lq_batch) == args.batch_size: + lq_files.append(lq_batch) + lq_batch = [] + + if len(lq_batch) > 0: + lq_files.append(lq_batch) + + for lq_batch in lq_files: + generator = torch.Generator(device=device).manual_seed(args.seed) + pil_lqs = [Image.open(os.path.join(args.test_path, file)) for file in lq_batch] + if args.width is None or args.height is None: + lq = [resize_img(pil_lq.convert("RGB"), size=None) for pil_lq in pil_lqs] + else: + lq = [resize_img(pil_lq.convert("RGB"), size=(args.width, args.height)) for pil_lq in pil_lqs] + timesteps = None + if args.denoising_start < 1000: + timesteps = [ + i * (args.denoising_start//args.num_inference_steps) + pipe.scheduler.config.steps_offset for i in range(0, args.num_inference_steps) + ] + timesteps = timesteps[::-1] + pipe.scheduler.set_timesteps(args.num_inference_steps, device) + timesteps = pipe.scheduler.timesteps + if args.prompt is None or len(args.prompt) == 0: + prompt = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \ + ultra HD, extreme meticulous detailing, skin pore detailing, \ + hyper sharpness, perfect without deformations, \ + taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. " + else: + prompt = args.prompt + if not isinstance(prompt, list): + prompt = [prompt] + prompt = prompt*len(lq) + if args.neg_prompt is None or len(args.neg_prompt) == 0: + neg_prompt = "blurry, out of focus, unclear, depth of field, over-smooth, \ + sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \ + dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \ + watermark, signature, jpeg artifacts, deformed, lowres" + else: + neg_prompt = args.neg_prompt + if not isinstance(neg_prompt, list): + neg_prompt = [neg_prompt] + neg_prompt = neg_prompt*len(lq) + image = pipe( + prompt=prompt, + image=lq, + num_inference_steps=args.num_inference_steps, + generator=generator, + timesteps=timesteps, + negative_prompt=neg_prompt, + guidance_scale=args.cfg, + previewer_scheduler=lcm_scheduler, + preview_start=args.preview_start, + control_guidance_end=args.creative_start, + ).images + + if args.save_preview_row: + for i, lcm_image in enumerate(image[1]): + lcm_image.save(f"./lcm/{i}.png") + for i, rec_image in enumerate(image): + rec_image.save(f"{args.out_path}/{post_fix}/{lq_batch[i]}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="InstantIR pipeline") + parser.add_argument( + "--sdxl_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--previewer_lora_path", + type=str, + default=None, + help="Path to LCM lora or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--instantir_path", + type=str, + default=None, + required=True, + help="Path to pretrained instantir model.", + ) + parser.add_argument( + "--vision_encoder_path", + type=str, + default='/share/huangrenyuan/model_zoo/vis_backbone/dinov2_large', + help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--adapter_model_path", + type=str, + default=None, + help="Path to IP-Adapter models or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--adapter_tokens", + type=int, + default=64, + help="Number of tokens to use in IP-adapter cross attention mechanism.", + ) + parser.add_argument( + "--use_clip_encoder", + action="store_true", + help="Whether or not to use DINO as image encoder, else CLIP encoder.", + ) + parser.add_argument( + "--denoising_start", + type=int, + default=1000, + help="Diffusion start timestep." + ) + parser.add_argument( + "--num_inference_steps", + type=int, + default=30, + help="Diffusion steps." + ) + parser.add_argument( + "--creative_start", + type=float, + default=1.0, + help="Proportion of timesteps for creative restoration. 1.0 means no creative restoration while 0.0 means completely free rendering." + ) + parser.add_argument( + "--preview_start", + type=float, + default=0.0, + help="Proportion of timesteps to stop previewing at the begining to enhance fidelity to input." + ) + parser.add_argument( + "--resolution", + type=int, + default=1024, + help="Number of tokens to use in IP-adapter cross attention mechanism.", + ) + parser.add_argument( + "--batch_size", + type=int, + default=6, + help="Test batch size." + ) + parser.add_argument( + "--width", + type=int, + default=None, + help="Output image width." + ) + parser.add_argument( + "--height", + type=int, + default=None, + help="Output image height." + ) + parser.add_argument( + "--cfg", + type=float, + default=7.0, + help="Scale of Classifier-Free-Guidance (CFG).", + ) + parser.add_argument( + "--post_fix", + type=str, + default=None, + help="Subfolder name for restoration output under the output directory.", + ) + parser.add_argument( + "--variant", + type=str, + default='fp16', + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--save_preview_row", + action="store_true", + help="Whether or not to save the intermediate lcm outputs.", + ) + parser.add_argument( + "--prompt", + type=str, + default='', + nargs="+", + help=( + "A set of prompts for creative restoration. Provide either a matching number of test images," + " or a single prompt to be used with all inputs." + ), + ) + parser.add_argument( + "--neg_prompt", + type=str, + default='', + nargs="+", + help=( + "A set of negative prompts for creative restoration. Provide either a matching number of test images," + " or a single negative prompt to be used with all inputs." + ), + ) + parser.add_argument( + "--test_path", + type=str, + default=None, + required=True, + help="Test directory.", + ) + parser.add_argument( + "--out_path", + type=str, + default="./output", + help="Output directory.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + args = parser.parse_args() + args.height = args.height or args.width + args.width = args.width or args.height + if args.height is not None and (args.width % 64 != 0 or args.height % 64 != 0): + raise ValueError("Image resolution must be divisible by 64.") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + main(args, device) \ No newline at end of file diff --git a/infer.sh b/infer.sh new file mode 100644 index 0000000000000000000000000000000000000000..9d7e54337d06dd33b857ce4149d164e8ac8cdb3a --- /dev/null +++ b/infer.sh @@ -0,0 +1,6 @@ +python infer.py \ + --sdxl_path path/to/sdxl \ + --vision_encoder_path path/to/dinov2_large \ + --instantir_path path/to/instantir \ + --test_path path/to/input \ + --out_path path/to/output \ No newline at end of file diff --git a/losses/loss_config.py b/losses/loss_config.py new file mode 100644 index 0000000000000000000000000000000000000000..64a384818982b598f484eeaf908ec91c1782782f --- /dev/null +++ b/losses/loss_config.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass, field +from typing import List + +@dataclass +class SingleLossConfig: + name: str + weight: float = 1. + init_params: dict = field(default_factory=dict) + visualize_every_k: int = -1 + + +@dataclass +class LossesConfig: + diffusion_losses: List[SingleLossConfig] + lcm_losses: List[SingleLossConfig] \ No newline at end of file diff --git a/losses/losses.py b/losses/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..71de2791906d1f68871a16974bb2cd9d7fabbf37 --- /dev/null +++ b/losses/losses.py @@ -0,0 +1,465 @@ +import torch +import wandb +import cv2 +import torch.nn.functional as F +import numpy as np +from facenet_pytorch import MTCNN +from torchvision import transforms +from dreamsim import dreamsim +from einops import rearrange +import kornia.augmentation as K +import lpips + +from pretrained_models.arcface import Backbone +from utils.vis_utils import add_text_to_image +from utils.utils import extract_faces_and_landmarks +import clip + + +class Loss(): + """ + General purpose loss class. + Mainly handles dtype and visualize_every_k. + keeps current iteration of loss, mainly for visualization purposes. + """ + def __init__(self, visualize_every_k=-1, dtype=torch.float32, accelerator=None, **kwargs): + self.visualize_every_k = visualize_every_k + self.iteration = -1 + self.dtype=dtype + self.accelerator = accelerator + + def __call__(self, **kwargs): + self.iteration += 1 + return self.forward(**kwargs) + + +class L1Loss(Loss): + """ + Simple L1 loss between predicted_pixel_values and pixel_values + + Args: + predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder. + encoder_pixel_values (torch.Tesnor): The input image to the encoder + """ + def forward( + self, + predict: torch.Tensor, + target: torch.Tensor, + **kwargs + ) -> torch.Tensor: + return F.l1_loss(predict, target, reduction="mean") + + +class DreamSIMLoss(Loss): + """DreamSIM loss between predicted_pixel_values and pixel_values. + DreamSIM is similar to LPIPS (https://dreamsim-nights.github.io/) but is trained on more human defined similarity dataset + DreamSIM expects an RGB image of size 224x224 and values between 0 and 1. So we need to normalize the input images to 0-1 range and resize them to 224x224. + Args: + predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder. + encoder_pixel_values (torch.Tesnor): The input image to the encoder + """ + def __init__(self, device: str='cuda:0', **kwargs): + super().__init__(**kwargs) + self.model, _ = dreamsim(pretrained=True, device=device) + self.model.to(dtype=self.dtype, device=device) + self.model = self.accelerator.prepare(self.model) + self.transforms = transforms.Compose([ + transforms.Lambda(lambda x: (x + 1) / 2), + transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC)]) + + def forward( + self, + predicted_pixel_values: torch.Tensor, + encoder_pixel_values: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + predicted_pixel_values.to(dtype=self.dtype) + encoder_pixel_values.to(dtype=self.dtype) + return self.model(self.transforms(predicted_pixel_values), self.transforms(encoder_pixel_values)).mean() + + +class LPIPSLoss(Loss): + """LPIPS loss between predicted_pixel_values and pixel_values. + Args: + predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder. + encoder_pixel_values (torch.Tesnor): The input image to the encoder + """ + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model = lpips.LPIPS(net='vgg') + self.model.to(dtype=self.dtype, device=self.accelerator.device) + self.model = self.accelerator.prepare(self.model) + + def forward(self, predict, target, **kwargs): + predict.to(dtype=self.dtype) + target.to(dtype=self.dtype) + return self.model(predict, target).mean() + + +class LCMVisualization(Loss): + """Dummy loss used to visualize the LCM outputs + Args: + predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder. + pixel_values (torch.Tensor): The input image to the decoder + encoder_pixel_values (torch.Tesnor): The input image to the encoder + """ + def forward( + self, + predicted_pixel_values: torch.Tensor, + pixel_values: torch.Tensor, + encoder_pixel_values: torch.Tensor, + timesteps: torch.Tensor, + **kwargs, + ) -> None: + if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0: + predicted_pixel_values = rearrange(predicted_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy() + pixel_values = rearrange(pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy() + encoder_pixel_values = rearrange(encoder_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy() + image = np.hstack([encoder_pixel_values, pixel_values, predicted_pixel_values]) + for tracker in self.accelerator.trackers: + if tracker.name == 'wandb': + tracker.log({"TrainVisualization": wandb.Image(image, caption=f"Encoder Input Image, Decoder Input Image, Predicted LCM Image. Timesteps {timesteps.cpu().tolist()}")}) + return torch.tensor(0.0) + + +class L2Loss(Loss): + """ + Regular diffusion loss between predicted noise and target noise. + + Args: + predicted_noise (torch.Tensor): noise predicted by the diffusion model + target_noise (torch.Tensor): actual noise added to the image. + """ + def forward( + self, + predict: torch.Tensor, + target: torch.Tensor, + weights: torch.Tensor = None, + **kwargs + ) -> torch.Tensor: + if weights is not None: + loss = (predict.float() - target.float()).pow(2) * weights + return loss.mean() + return F.mse_loss(predict.float(), target.float(), reduction="mean") + + +class HuberLoss(Loss): + """Huber loss between predicted_pixel_values and pixel_values. + Args: + predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder. + encoder_pixel_values (torch.Tesnor): The input image to the encoder + """ + def __init__(self, huber_c=0.001, **kwargs): + super().__init__(**kwargs) + self.huber_c = huber_c + + def forward( + self, + predict: torch.Tensor, + target: torch.Tensor, + weights: torch.Tensor = None, + **kwargs + ) -> torch.Tensor: + loss = torch.sqrt((predict.float() - target.float()) ** 2 + self.huber_c**2) - self.huber_c + if weights is not None: + return (loss * weights).mean() + return loss.mean() + + +class WeightedNoiseLoss(Loss): + """ + Weighted diffusion loss between predicted noise and target noise. + + Args: + predicted_noise (torch.Tensor): noise predicted by the diffusion model + target_noise (torch.Tensor): actual noise added to the image. + loss_batch_weights (torch.Tensor): weighting for each batch item. Can be used to e.g. zero-out loss for InstantID training if keypoint extraction fails. + """ + def forward( + self, + predict: torch.Tensor, + target: torch.Tensor, + weights, + **kwargs + ) -> torch.Tensor: + return F.mse_loss(predict.float() * weights, target.float() * weights, reduction="mean") + + +class IDLoss(Loss): + """ + Use pretrained facenet model to extract features from the face of the predicted image and target image. + Facenet expects 112x112 images, so we crop the face using MTCNN and resize it to 112x112. + Then we use the cosine similarity between the features to calculate the loss. (The cosine similarity is 1 - cosine distance). + Also notice that the outputs of facenet are normalized so the dot product is the same as cosine distance. + """ + def __init__(self, pretrained_arcface_path: str, skip_not_found=True, **kwargs): + super().__init__(**kwargs) + assert pretrained_arcface_path is not None, "please pass `pretrained_arcface_path` in the losses config. You can download the pretrained model from "\ + "https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing" + self.mtcnn = MTCNN(device=self.accelerator.device) + self.mtcnn.forward = self.mtcnn.detect + self.facenet_input_size = 112 # Has to be 112, can't find weights for 224 size. + self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') + self.facenet.load_state_dict(torch.load(pretrained_arcface_path)) + self.face_pool = torch.nn.AdaptiveAvgPool2d((self.facenet_input_size, self.facenet_input_size)) + self.facenet.requires_grad_(False) + self.facenet.eval() + self.facenet.to(device=self.accelerator.device, dtype=self.dtype) # not implemented for half precision + self.face_pool.to(device=self.accelerator.device, dtype=self.dtype) # not implemented for half precision + self.visualization_resize = transforms.Resize((self.facenet_input_size, self.facenet_input_size), interpolation=transforms.InterpolationMode.BICUBIC) + self.reference_facial_points = np.array([[38.29459953, 51.69630051], + [72.53179932, 51.50139999], + [56.02519989, 71.73660278], + [41.54930115, 92.3655014], + [70.72990036, 92.20410156] + ]) # Original points are 112 * 96 added 8 to the x axis to make it 112 * 112 + self.facenet, self.face_pool, self.mtcnn = self.accelerator.prepare(self.facenet, self.face_pool, self.mtcnn) + + self.skip_not_found = skip_not_found + + def extract_feats(self, x: torch.Tensor): + """ + Extract features from the face of the image using facenet model. + """ + x = self.face_pool(x) + x_feats = self.facenet(x) + + return x_feats + + def forward( + self, + predicted_pixel_values: torch.Tensor, + encoder_pixel_values: torch.Tensor, + timesteps: torch.Tensor, + **kwargs + ): + encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype) + predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype) + + predicted_pixel_values_face, predicted_invalid_indices = extract_faces_and_landmarks(predicted_pixel_values, mtcnn=self.mtcnn) + with torch.no_grad(): + encoder_pixel_values_face, source_invalid_indices = extract_faces_and_landmarks(encoder_pixel_values, mtcnn=self.mtcnn) + + if self.skip_not_found: + valid_indices = [] + for i in range(predicted_pixel_values.shape[0]): + if i not in predicted_invalid_indices and i not in source_invalid_indices: + valid_indices.append(i) + else: + valid_indices = list(range(predicted_pixel_values)) + + valid_indices = torch.tensor(valid_indices).to(device=predicted_pixel_values.device) + + if len(valid_indices) == 0: + loss = (predicted_pixel_values_face * 0.0).mean() # It's done this way so the `backwards` will delete the computation graph of the predicted_pixel_values. + if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0: + self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss) + return loss + + with torch.no_grad(): + pixel_values_feats = self.extract_feats(encoder_pixel_values_face[valid_indices]) + + predicted_pixel_values_feats = self.extract_feats(predicted_pixel_values_face[valid_indices]) + loss = 1 - torch.einsum("bi,bi->b", pixel_values_feats, predicted_pixel_values_feats) + + if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0: + self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss) + return loss.mean() + + def visualize( + self, + predicted_pixel_values: torch.Tensor, + encoder_pixel_values: torch.Tensor, + predicted_pixel_values_face: torch.Tensor, + encoder_pixel_values_face: torch.Tensor, + timesteps: torch.Tensor, + valid_indices: torch.Tensor, + loss: torch.Tensor, + ) -> None: + small_predicted_pixel_values = (rearrange(self.visualization_resize(predicted_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy()) + small_pixle_values = rearrange(self.visualization_resize(encoder_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy() + small_predicted_pixel_values_face = rearrange(self.visualization_resize(predicted_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy() + small_pixle_values_face = rearrange(self.visualization_resize(encoder_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy() + + small_predicted_pixel_values = add_text_to_image(((small_predicted_pixel_values * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Images", add_below=False) + small_pixle_values = add_text_to_image(((small_pixle_values * 0.5 + 0.5) * 255).astype(np.uint8), "Target Images", add_below=False) + small_predicted_pixel_values_face = add_text_to_image(((small_predicted_pixel_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Faces", add_below=False) + small_pixle_values_face = add_text_to_image(((small_pixle_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Target Faces", add_below=False) + + + final_image = np.hstack([small_predicted_pixel_values, small_pixle_values, small_predicted_pixel_values_face, small_pixle_values_face]) + for tracker in self.accelerator.trackers: + if tracker.name == 'wandb': + tracker.log({"IDLoss Visualization": wandb.Image(final_image, caption=f"loss: {loss.cpu().tolist()} timesteps: {timesteps.cpu().tolist()}, valid_indices: {valid_indices.cpu().tolist()}")}) + + +class ImageAugmentations(torch.nn.Module): + # Standard image augmentations used for CLIP loss to discourage adversarial outputs. + def __init__(self, output_size, augmentations_number, p=0.7): + super().__init__() + self.output_size = output_size + self.augmentations_number = augmentations_number + + self.augmentations = torch.nn.Sequential( + K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"), # type: ignore + K.RandomPerspective(0.7, p=p), + ) + + self.avg_pool = torch.nn.AdaptiveAvgPool2d((self.output_size, self.output_size)) + + self.device = None + + def forward(self, input): + """Extents the input batch with augmentations + If the input is consists of images [I1, I2] the extended augmented output + will be [I1_resized, I2_resized, I1_aug1, I2_aug1, I1_aug2, I2_aug2 ...] + Args: + input ([type]): input batch of shape [batch, C, H, W] + Returns: + updated batch: of shape [batch * augmentations_number, C, H, W] + """ + # We want to multiply the number of images in the batch in contrast to regular augmantations + # that do not change the number of samples in the batch) + resized_images = self.avg_pool(input) + resized_images = torch.tile(resized_images, dims=(self.augmentations_number, 1, 1, 1)) + + batch_size = input.shape[0] + # We want at least one non augmented image + non_augmented_batch = resized_images[:batch_size] + augmented_batch = self.augmentations(resized_images[batch_size:]) + updated_batch = torch.cat([non_augmented_batch, augmented_batch], dim=0) + + return updated_batch + + +class CLIPLoss(Loss): + def __init__(self, augmentations_number: int = 4, **kwargs): + super().__init__(**kwargs) + + self.clip_model, clip_preprocess = clip.load("ViT-B/16", device=self.accelerator.device, jit=False) + + self.clip_model.device = None + + self.clip_model.eval().requires_grad_(False) + + self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (SD output) to [0, 1]. + clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions + clip_preprocess.transforms[4:]) # + skip convert PIL to tensor + + self.clip_size = self.clip_model.visual.input_resolution + + self.clip_normalize = transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] + ) + + self.image_augmentations = ImageAugmentations(output_size=self.clip_size, + augmentations_number=augmentations_number) + + self.clip_model, self.image_augmentations = self.accelerator.prepare(self.clip_model, self.image_augmentations) + + def forward(self, decoder_prompts, predicted_pixel_values: torch.Tensor, **kwargs) -> torch.Tensor: + + if not isinstance(decoder_prompts, list): + decoder_prompts = [decoder_prompts] + + tokens = clip.tokenize(decoder_prompts).to(predicted_pixel_values.device) + image = self.preprocess(predicted_pixel_values) + + logits_per_image, _ = self.clip_model(image, tokens) + + logits_per_image = torch.diagonal(logits_per_image) + + return (1. - logits_per_image / 100).mean() + + +class DINOLoss(Loss): + def __init__( + self, + dino_model, + dino_preprocess, + output_hidden_states: bool = False, + center_momentum: float = 0.9, + student_temp: float = 0.1, + teacher_temp: float = 0.04, + warmup_teacher_temp: float = 0.04, + warmup_teacher_temp_epochs: int = 30, + **kwargs): + super().__init__(**kwargs) + + self.dino_model = dino_model + self.output_hidden_states = output_hidden_states + self.rescale_factor = dino_preprocess.rescale_factor + + # Un-normalize from [-1.0, 1.0] (SD output) to [0, 1]. + self.preprocess = transforms.Compose( + [ + transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]), + transforms.Resize(size=256), + transforms.CenterCrop(size=(224, 224)), + transforms.Normalize(mean=dino_preprocess.image_mean, std=dino_preprocess.image_std) + ] + ) + + self.student_temp = student_temp + self.teacher_temp = teacher_temp + self.center_momentum = center_momentum + self.center = torch.zeros(1, 257, 1024).to(self.accelerator.device, dtype=self.dtype) + + # TODO: add temp, now fixed to 0.04 + # we apply a warm up for the teacher temperature because + # a too high temperature makes the training instable at the beginning + # self.teacher_temp_schedule = np.concatenate(( + # np.linspace(warmup_teacher_temp, + # teacher_temp, warmup_teacher_temp_epochs), + # np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp + # )) + + self.dino_model = self.accelerator.prepare(self.dino_model) + + def forward( + self, + target: torch.Tensor, + predict: torch.Tensor, + weights: torch.Tensor = None, + **kwargs) -> torch.Tensor: + + predict = self.preprocess(predict) + target = self.preprocess(target) + + encoder_input = torch.cat([target, predict]).to(self.dino_model.device, dtype=self.dino_model.dtype) + + if self.output_hidden_states: + raise ValueError("Output hidden states not supported for DINO loss.") + image_enc_hidden_states = self.dino_model(encoder_input, output_hidden_states=True).hidden_states[-2] + else: + image_enc_hidden_states = self.dino_model(encoder_input).last_hidden_state + + teacher_output, student_output = image_enc_hidden_states.chunk(2, dim=0) # [B, 257, 1024] + + student_out = student_output.float() / self.student_temp + + # teacher centering and sharpening + # temp = self.teacher_temp_schedule[epoch] + temp = self.teacher_temp + teacher_out = F.softmax((teacher_output.float() - self.center) / temp, dim=-1) + teacher_out = teacher_out.detach() + + loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1, keepdim=True) + # self.update_center(teacher_output) + + if weights is not None: + loss = loss * weights + return loss.mean() + return loss.mean() + + @torch.no_grad() + def update_center(self, teacher_output): + """ + Update center used for teacher output. + """ + batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + self.accelerator.reduce(batch_center, reduction="sum") + batch_center = batch_center / (len(teacher_output) * self.accelerator.num_processes) + + # ema update + self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) diff --git a/module/aggregator.py b/module/aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..fd615100376e019a2fd6dd6a94d176f6047c1762 --- /dev/null +++ b/module/aggregator.py @@ -0,0 +1,983 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, +) +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class ZeroConv(nn.Module): + def __init__(self, label_nc, norm_nc, mask=False): + super().__init__() + self.zero_conv = zero_module(nn.Conv2d(label_nc+norm_nc, norm_nc, 1, 1, 0)) + self.mask = mask + + def forward(self, hidden_states, h_ori=None): + # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32): + c, h = hidden_states + if not self.mask: + h = self.zero_conv(torch.cat([c, h], dim=1)) + else: + h = self.zero_conv(torch.cat([c, h], dim=1)) * torch.zeros_like(h) + if h_ori is not None: + h = torch.cat([h_ori, h], dim=1) + return h + + +class SFT(nn.Module): + def __init__(self, label_nc, norm_nc, mask=False): + super().__init__() + + # param_free_norm_type = str(parsed.group(1)) + ks = 3 + pw = ks // 2 + + self.mask = mask + + nhidden = 128 + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), + nn.SiLU() + ) + self.mul = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + self.add = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) + + def forward(self, hidden_states, mask=False): + + c, h = hidden_states + mask = mask or self.mask + assert mask is False + + actv = self.mlp_shared(c) + gamma = self.mul(actv) + beta = self.add(actv) + + if self.mask: + gamma = gamma * torch.zeros_like(gamma) + beta = beta * torch.zeros_like(beta) + # gamma_ori, gamma_res = torch.split(gamma, [h_ori_c, h_c], dim=1) + # beta_ori, beta_res = torch.split(beta, [h_ori_c, h_c], dim=1) + # print(gamma_ori.mean(), gamma_res.mean(), beta_ori.mean(), beta_res.mean()) + h = h * (gamma + 1) + beta + # sample_ori, sample_res = torch.split(h, [h_ori_c, h_c], dim=1) + # print(sample_ori.mean(), sample_res.mean()) + + return h + + +@dataclass +class AggregatorOutput(BaseOutput): + """ + The output of [`Aggregator`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the midde block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class ConditioningEmbedding(nn.Module): + """ + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN + [11] to convert the entire dataset of 512 ร— 512 images into smaller 64 ร— 64 โ€œlatent imagesโ€ for stabilized + training. This requires ControlNets to convert image-based conditions to 64 ร— 64 feature space to match the + convolution size. We use a tiny network E(ยท) of four convolution layers with 4 ร— 4 kernels and 2 ร— 2 strides + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full + model) to encode image-space conditions ... into feature maps ..." + """ + + def __init__( + self, + conditioning_embedding_channels: int, + conditioning_channels: int = 3, + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), + ): + super().__init__() + + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) + + self.blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels) - 1): + channel_in = block_out_channels[i] + channel_out = block_out_channels[i + 1] + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) + + self.conv_out = zero_module( + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) + ) + + def forward(self, conditioning): + embedding = self.conv_in(conditioning) + embedding = F.silu(embedding) + + for block in self.blocks: + embedding = block(embedding) + embedding = F.silu(embedding) + + embedding = self.conv_out(embedding) + + return embedding + + +class Aggregator(ModelMixin, ConfigMixin, FromOriginalModelMixin): + """ + Aggregator model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 3, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads: int = 64, + pad_concat: bool = False, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + self.pad_concat = pad_concat + + # Check inputs + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + # control net conditioning embedding + self.ref_conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + self.down_blocks = nn.ModuleList([]) + self.controlnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + # controlnet_block = ZeroConv(output_channel, output_channel) + controlnet_block = nn.Sequential( + SFT(output_channel, output_channel), + zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1)) + ) + self.controlnet_down_blocks.append(controlnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + # controlnet_block = ZeroConv(output_channel, output_channel) + controlnet_block = nn.Sequential( + SFT(output_channel, output_channel), + zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1)) + ) + self.controlnet_down_blocks.append(controlnet_block) + + if not is_final_block: + # controlnet_block = ZeroConv(output_channel, output_channel) + controlnet_block = nn.Sequential( + SFT(output_channel, output_channel), + zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1)) + ) + self.controlnet_down_blocks.append(controlnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + # controlnet_block = ZeroConv(mid_block_channel, mid_block_channel) + controlnet_block = nn.Sequential( + SFT(mid_block_channel, mid_block_channel), + zero_module(nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)) + ) + self.controlnet_mid_block = controlnet_block + + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + controlnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 3, + ): + r""" + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=unet.config.in_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=unet.config.down_block_types, + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + mid_block_type=unet.config.mid_block_type, + controlnet_conditioning_channel_order=controlnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + conditioning_channels=conditioning_channels, + ) + + if load_weights_from_unet: + controlnet.conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.ref_conv_in.load_state_dict(unet.conv_in.state_dict()) + controlnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if controlnet.class_embedding: + controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + if hasattr(controlnet, "add_embedding"): + controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict()) + + controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict()) + controlnet.mid_block.load_state_dict(unet.mid_block.state_dict()) + + return controlnet + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def process_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> torch.Tensor: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + return encoder_hidden_states + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + controlnet_cond: torch.FloatTensor, + cat_dim: int = -2, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[AggregatorOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: + """ + The [`Aggregator`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.controlnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + else: + raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if self.config.addition_embed_type is not None: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + # 2. prepare input + cond_latent = self.conv_in(sample) + ref_latent = self.ref_conv_in(controlnet_cond) + batch_size, channel, height, width = cond_latent.shape + if self.pad_concat: + if cat_dim == -2 or cat_dim == 2: + concat_pad = torch.zeros(batch_size, channel, 1, width) + elif cat_dim == -1 or cat_dim == 3: + concat_pad = torch.zeros(batch_size, channel, height, 1) + else: + raise ValueError(f"Aggregator shall concat along spatial dimension, but is asked to concat dim: {cat_dim}.") + concat_pad = concat_pad.to(cond_latent.device, dtype=cond_latent.dtype) + sample = torch.cat([cond_latent, concat_pad, ref_latent], dim=cat_dim) + else: + sample = torch.cat([cond_latent, ref_latent], dim=cat_dim) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # rebuild sample: split and concat + if self.pad_concat: + batch_size, channel, height, width = sample.shape + if cat_dim == -2 or cat_dim == 2: + cond_latent = sample[:, :, :height//2, :] + ref_latent = sample[:, :, -(height//2):, :] + concat_pad = torch.zeros(batch_size, channel, 1, width) + elif cat_dim == -1 or cat_dim == 3: + cond_latent = sample[:, :, :, :width//2] + ref_latent = sample[:, :, :, -(width//2):] + concat_pad = torch.zeros(batch_size, channel, height, 1) + concat_pad = concat_pad.to(cond_latent.device, dtype=cond_latent.dtype) + sample = torch.cat([cond_latent, concat_pad, ref_latent], dim=cat_dim) + res_samples = res_samples[:-1] + (sample,) + + down_block_res_samples += res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + cross_attention_kwargs=cross_attention_kwargs, + ) + + # 5. split samples and SFT. + controlnet_down_block_res_samples = () + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + batch_size, channel, height, width = down_block_res_sample.shape + if cat_dim == -2 or cat_dim == 2: + cond_latent = down_block_res_sample[:, :, :height//2, :] + ref_latent = down_block_res_sample[:, :, -(height//2):, :] + elif cat_dim == -1 or cat_dim == 3: + cond_latent = down_block_res_sample[:, :, :, :width//2] + ref_latent = down_block_res_sample[:, :, :, -(width//2):] + down_block_res_sample = controlnet_block((cond_latent, ref_latent), ) + controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + batch_size, channel, height, width = sample.shape + if cat_dim == -2 or cat_dim == 2: + cond_latent = sample[:, :, :height//2, :] + ref_latent = sample[:, :, -(height//2):, :] + elif cat_dim == -1 or cat_dim == 3: + cond_latent = sample[:, :, :, :width//2] + ref_latent = sample[:, :, :, -(width//2):] + mid_block_res_sample = self.controlnet_mid_block((cond_latent, ref_latent), ) + + # 6. scaling + down_block_res_samples = [sample*conditioning_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample*conditioning_scale + + if self.config.global_pool_conditions: + down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples + ] + mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True) + + if not return_dict: + return (down_block_res_samples, mid_block_res_sample) + + return AggregatorOutput( + down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/module/attention.py b/module/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..12b5f2539b9e7f86a264fa7eb7c2fc3680379d28 --- /dev/null +++ b/module/attention.py @@ -0,0 +1,259 @@ +# Copy from diffusers.models.attention.py + +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import SinusoidalPositionalEmbedding +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm + +from module.min_sdxl import LoRACompatibleLinear, LoRALinearLayer + + +logger = logging.get_logger(__name__) + +def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + +def maybe_grad_checkpoint(resnet, attn, hidden_states, temb, encoder_hidden_states, adapter_hidden_states, do_ckpt=True): + + if do_ckpt: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states, extracted_kv = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn), hidden_states, encoder_hidden_states, adapter_hidden_states, use_reentrant=False + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states, extracted_kv = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + adapter_hidden_states=adapter_hidden_states, + ) + return hidden_states, extracted_kv + + +def init_lora_in_attn(attn_module, rank: int = 4, is_kvcopy=False): + # Set the `lora_layer` attribute of the attention-related matrices. + + attn_module.to_k.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=rank + ) + ) + attn_module.to_v.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=rank + ) + ) + + if not is_kvcopy: + attn_module.to_q.set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=rank + ) + ) + + attn_module.to_out[0].set_lora_layer( + LoRALinearLayer( + in_features=attn_module.to_out[0].in_features, + out_features=attn_module.to_out[0].out_features, + rank=rank, + ) + ) + +def drop_kvs(encoder_kvs, drop_chance): + for layer in encoder_kvs: + len_tokens = encoder_kvs[layer].self_attention.k.shape[1] + idx_to_keep = (torch.rand(len_tokens) > drop_chance) + + encoder_kvs[layer].self_attention.k = encoder_kvs[layer].self_attention.k[:, idx_to_keep] + encoder_kvs[layer].self_attention.v = encoder_kvs[layer].self_attention.v[:, idx_to_keep] + + return encoder_kvs + +def clone_kvs(encoder_kvs): + cloned_kvs = {} + for layer in encoder_kvs: + sa_cpy = KVCache(k=encoder_kvs[layer].self_attention.k.clone(), + v=encoder_kvs[layer].self_attention.v.clone()) + + ca_cpy = KVCache(k=encoder_kvs[layer].cross_attention.k.clone(), + v=encoder_kvs[layer].cross_attention.v.clone()) + + cloned_layer_cache = AttentionCache(self_attention=sa_cpy, cross_attention=ca_cpy) + + cloned_kvs[layer] = cloned_layer_cache + + return cloned_kvs + + +class KVCache(object): + def __init__(self, k, v): + self.k = k + self.v = v + +class AttentionCache(object): + def __init__(self, self_attention: KVCache, cross_attention: KVCache): + self.self_attention = self_attention + self.cross_attention = cross_attention + +class KVCopy(nn.Module): + def __init__( + self, inner_dim, cross_attention_dim=None, + ): + super(KVCopy, self).__init__() + + in_dim = cross_attention_dim or inner_dim + + self.to_k = LoRACompatibleLinear(in_dim, inner_dim, bias=False) + self.to_v = LoRACompatibleLinear(in_dim, inner_dim, bias=False) + + def forward(self, hidden_states): + + k = self.to_k(hidden_states) + v = self.to_v(hidden_states) + + return KVCache(k=k, v=v) + + def init_kv_copy(self, source_attn): + with torch.no_grad(): + self.to_k.weight.copy_(source_attn.to_k.weight) + self.to_v.weight.copy_(source_attn.to_v.weight) + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + return ff_output + + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x diff --git a/module/diffusers_vae/autoencoder_kl.py b/module/diffusers_vae/autoencoder_kl.py new file mode 100644 index 0000000000000000000000000000000000000000..60e897e59df853491e1fb07d06a9c21fcafd8cac --- /dev/null +++ b/module/diffusers_vae/autoencoder_kl.py @@ -0,0 +1,489 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import FromOriginalVAEMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin +from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin): + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): Sample input size. + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, *optional*, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co./madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + force_upcast: float = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + @apply_forward_hook + def encode( + self, x: torch.FloatTensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.tiled_encode(x, return_dict=return_dict) + + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for y in range(blend_extent): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for x in range(blend_extent): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + output, but they should be much less noticeable. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Decode a batch of images using a tiled decoder. + + Args: + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is ๐Ÿงช experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is ๐Ÿงช experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) \ No newline at end of file diff --git a/module/diffusers_vae/vae.py b/module/diffusers_vae/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..1315f9dba257555497564e0ce334f6c3d6ac3933 --- /dev/null +++ b/module/diffusers_vae/vae.py @@ -0,0 +1,985 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn + +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils.torch_utils import randn_tensor +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import SpatialNorm +from diffusers.models.unet_2d_blocks import ( + AutoencoderTinyBlock, + UNetMidBlock2D, + get_down_block, + get_up_block, +) + + +@dataclass +class DecoderOutput(BaseOutput): + r""" + Output of decoding method. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.FloatTensor + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available + options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + double_z (`bool`, *optional*, defaults to `True`): + Whether to double the number of output channels for the last block. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=self.layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + add_downsample=not is_final_block, + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=None, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + add_attention=mid_block_add_attention, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = self.conv_in(sample) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # down + if is_torch_version(">=", "1.11.0"): + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(down_block), sample, use_reentrant=False + ) + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, use_reentrant=False + ) + else: + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + # middle + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + + else: + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + sample: torch.FloatTensor, + latent_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + + sample = self.conv_in(sample) + sample = sample.to(torch.float32) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + latent_embeds, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, + ) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds + ) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # up + for up_block in self.up_blocks: + sample = up_block(sample, latent_embeds) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class UpSample(nn.Module): + r""" + The `UpSample` layer of a variational autoencoder that upsamples its input. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1) + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `UpSample` class.""" + x = torch.relu(x) + x = self.deconv(x) + return x + + +class MaskConditionEncoder(nn.Module): + """ + used in AsymmetricAutoencoderKL + """ + + def __init__( + self, + in_ch: int, + out_ch: int = 192, + res_ch: int = 768, + stride: int = 16, + ) -> None: + super().__init__() + + channels = [] + while stride > 1: + stride = stride // 2 + in_ch_ = out_ch * 2 + if out_ch > res_ch: + out_ch = res_ch + if stride == 1: + in_ch_ = res_ch + channels.append((in_ch_, out_ch)) + out_ch *= 2 + + out_channels = [] + for _in_ch, _out_ch in channels: + out_channels.append(_out_ch) + out_channels.append(channels[-1][0]) + + layers = [] + in_ch_ = in_ch + for l in range(len(out_channels)): + out_ch_ = out_channels[l] + if l == 0 or l == 1: + layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1)) + else: + layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1)) + in_ch_ = out_ch_ + + self.layers = nn.Sequential(*layers) + + def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor: + r"""The forward method of the `MaskConditionEncoder` class.""" + out = {} + for l in range(len(self.layers)): + layer = self.layers[l] + x = layer(x) + out[str(tuple(x.shape))] = x + x = torch.relu(x) + return out + + +class MaskConditionDecoder(nn.Module): + r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's + decoder with a conditioner on the mask and masked image. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + act_fn (`str`, *optional*, defaults to `"silu"`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + norm_type (`str`, *optional*, defaults to `"group"`): + The normalization type to use. Can be either `"group"` or `"spatial"`. + """ + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = get_up_block( + up_block_type, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + prev_output_channel=None, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + attention_head_dim=output_channel, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # condition encoder + self.condition_encoder = MaskConditionEncoder( + in_ch=out_channels, + out_ch=block_out_channels[0], + res_ch=block_out_channels[-1], + ) + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward( + self, + z: torch.FloatTensor, + image: Optional[torch.FloatTensor] = None, + mask: Optional[torch.FloatTensor] = None, + latent_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `MaskConditionDecoder` class.""" + sample = z + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), + sample, + latent_embeds, + use_reentrant=False, + ) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.condition_encoder), + masked_image, + mask, + use_reentrant=False, + ) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(up_block), + sample, + latent_embeds, + use_reentrant=False, + ) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + else: + # middle + sample = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.mid_block), sample, latent_embeds + ) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.condition_encoder), + masked_image, + mask, + ) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + else: + # middle + sample = self.mid_block(sample, latent_embeds) + sample = sample.to(upscale_dtype) + + # condition encoder + if image is not None and mask is not None: + masked_image = (1 - mask) * image + im_x = self.condition_encoder(masked_image, mask) + + # up + for up_block in self.up_blocks: + if image is not None and mask is not None: + sample_ = im_x[str(tuple(sample.shape))] + mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest") + sample = sample * mask_ + sample_ * (1 - mask_) + sample = up_block(sample, latent_embeds) + if image is not None and mask is not None: + sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask) + + # post-process + if latent_embeds is None: + sample = self.conv_norm_out(sample) + else: + sample = self.conv_norm_out(sample, latent_embeds) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample + + +class VectorQuantizer(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__( + self, + n_e: int, + vq_embed_dim: int, + beta: float, + remap=None, + unknown_index: str = "random", + sane_index_shape: bool = False, + legacy: bool = True, + ): + super().__init__() + self.n_e = n_e + self.vq_embed_dim = vq_embed_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.used: torch.Tensor + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print( + f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices." + ) + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor: + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]: + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.vq_embed_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1) + + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q: torch.FloatTensor = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor: + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q: torch.FloatTensor = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self) -> torch.Tensor: + return self.mean + + +class EncoderTiny(nn.Module): + r""" + The `EncoderTiny` layer is a simpler version of the `Encoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`Tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`Tuple[int, ...]`): + The number of output channels for each block. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], + act_fn: str, + ): + super().__init__() + + layers = [] + for i, num_block in enumerate(num_blocks): + num_channels = block_out_channels[i] + + if i == 0: + layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1)) + else: + layers.append( + nn.Conv2d( + num_channels, + num_channels, + kernel_size=3, + padding=1, + stride=2, + bias=False, + ) + ) + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1)) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `EncoderTiny` class.""" + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) + + else: + # scale image from [-1, 1] to [0, 1] to match TAESD convention + x = self.layers(x.add(1).div(2)) + + return x + + +class DecoderTiny(nn.Module): + r""" + The `DecoderTiny` layer is a simpler version of the `Decoder` layer. + + Args: + in_channels (`int`): + The number of input channels. + out_channels (`int`): + The number of output channels. + num_blocks (`Tuple[int, ...]`): + Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to + use. + block_out_channels (`Tuple[int, ...]`): + The number of output channels for each block. + upsampling_scaling_factor (`int`): + The scaling factor to use for upsampling. + act_fn (`str`): + The activation function to use. See `~diffusers.models.activations.get_activation` for available options. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_blocks: Tuple[int, ...], + block_out_channels: Tuple[int, ...], + upsampling_scaling_factor: int, + act_fn: str, + ): + super().__init__() + + layers = [ + nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1), + get_activation(act_fn), + ] + + for i, num_block in enumerate(num_blocks): + is_final_block = i == (len(num_blocks) - 1) + num_channels = block_out_channels[i] + + for _ in range(num_block): + layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn)) + + if not is_final_block: + layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor)) + + conv_out_channel = num_channels if not is_final_block else out_channels + layers.append( + nn.Conv2d( + num_channels, + conv_out_channel, + kernel_size=3, + padding=1, + bias=is_final_block, + ) + ) + + self.layers = nn.Sequential(*layers) + self.gradient_checkpointing = False + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `DecoderTiny` class.""" + # Clamp. + x = torch.tanh(x / 3) * 3 + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False) + else: + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x) + + else: + x = self.layers(x) + + # scale image from [0, 1] to [-1, 1] to match diffusers convention + return x.mul(2).sub(1) \ No newline at end of file diff --git a/module/ip_adapter/attention_processor.py b/module/ip_adapter/attention_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..ed6cf755f7c2b36c81e6527d64bd0cc4e749d696 --- /dev/null +++ b/module/ip_adapter/attention_processor.py @@ -0,0 +1,1467 @@ +# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py +import torch +import torch.nn as nn +import torch.nn.functional as F + +class AdaLayerNorm(nn.Module): + def __init__(self, embedding_dim: int, time_embedding_dim: int = None): + super().__init__() + + if time_embedding_dim is None: + time_embedding_dim = embedding_dim + + self.silu = nn.SiLU() + self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True) + nn.init.zeros_(self.linear.weight) + nn.init.zeros_(self.linear.bias) + + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + + def forward( + self, x: torch.Tensor, timestep_embedding: torch.Tensor + ): + emb = self.linear(self.silu(timestep_embedding)) + shift, scale = emb.view(len(x), 1, -1).chunk(2, dim=-1) + x = self.norm(x) * (1 + scale) + shift + return x + + +class AttnProcessor(nn.Module): + r""" + Default processor for performing attention-related computations. + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class TA_IPAttnProcessor(nn.Module): + r""" + Attention processor for IP-Adapater. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4): + super().__init__() + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim) + self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + assert temb is not None, "Timestep embedding is needed for a time-aware attention processor." + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + # time-dependent adaLN + ip_key = self.ln_k_ip(ip_key, temb) + ip_value = self.ln_v_ip(ip_value, temb) + + ip_key = attn.head_to_batch_dim(ip_key) + ip_value = attn.head_to_batch_dim(ip_value) + + ip_attention_probs = attn.get_attention_scores(query, ip_key, None) + ip_hidden_states = torch.bmm(ip_attention_probs, ip_value) + ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + external_kv=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if external_kv: + key = torch.cat([key, external_kv.k], axis=1) + value = torch.cat([value, external_kv.v], axis=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class split_AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + time_embedding_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + external_kv=None, + temb=None, + cat_dim=-2, + original_shape=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + # 2d to sequence. + height, width = hidden_states.shape[-2:] + if cat_dim==-2 or cat_dim==2: + hidden_states_0 = hidden_states[:, :, :height//2, :] + hidden_states_1 = hidden_states[:, :, -(height//2):, :] + elif cat_dim==-1 or cat_dim==3: + hidden_states_0 = hidden_states[:, :, :, :width//2] + hidden_states_1 = hidden_states[:, :, :, -(width//2):] + batch_size, channel, height, width = hidden_states_0.shape + hidden_states_0 = hidden_states_0.view(batch_size, channel, height * width).transpose(1, 2) + hidden_states_1 = hidden_states_1.view(batch_size, channel, height * width).transpose(1, 2) + else: + # directly split sqeuence according to concat dim. + single_dim = original_shape[2] if cat_dim==-2 or cat_dim==2 else original_shape[1] + hidden_states_0 = hidden_states[:, :single_dim*single_dim,:] + hidden_states_1 = hidden_states[:, single_dim*(single_dim+1):,:] + + hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=1) + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + if external_kv: + key = torch.cat([key, external_kv.k], dim=1) + value = torch.cat([value, external_kv.v], dim=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + # spatially split. + hidden_states_0, hidden_states_1 = hidden_states.chunk(2, dim=1) + + if input_ndim == 4: + hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width) + hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if cat_dim==-2 or cat_dim==2: + hidden_states_pad = torch.zeros(batch_size, channel, 1, width) + elif cat_dim==-1 or cat_dim==3: + hidden_states_pad = torch.zeros(batch_size, channel, height, 1) + hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype) + hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim) + assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}" + else: + batch_size, sequence_length, inner_dim = hidden_states.shape + hidden_states_pad = torch.zeros(batch_size, single_dim, inner_dim) + hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype) + hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1) + assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}" + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class sep_split_AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__( + self, + hidden_size=None, + cross_attention_dim=None, + time_embedding_dim=None, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.ln_k_ref = AdaLayerNorm(hidden_size, time_embedding_dim) + self.ln_v_ref = AdaLayerNorm(hidden_size, time_embedding_dim) + # self.hidden_size = hidden_size + # self.cross_attention_dim = cross_attention_dim + # self.scale = scale + # self.num_tokens = num_tokens + + # self.to_q_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + # self.to_k_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + # self.to_v_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + external_kv=None, + temb=None, + cat_dim=-2, + original_shape=None, + ref_scale=1.0, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + # 2d to sequence. + height, width = hidden_states.shape[-2:] + if cat_dim==-2 or cat_dim==2: + hidden_states_0 = hidden_states[:, :, :height//2, :] + hidden_states_1 = hidden_states[:, :, -(height//2):, :] + elif cat_dim==-1 or cat_dim==3: + hidden_states_0 = hidden_states[:, :, :, :width//2] + hidden_states_1 = hidden_states[:, :, :, -(width//2):] + batch_size, channel, height, width = hidden_states_0.shape + hidden_states_0 = hidden_states_0.view(batch_size, channel, height * width).transpose(1, 2) + hidden_states_1 = hidden_states_1.view(batch_size, channel, height * width).transpose(1, 2) + else: + # directly split sqeuence according to concat dim. + single_dim = original_shape[2] if cat_dim==-2 or cat_dim==2 else original_shape[1] + hidden_states_0 = hidden_states[:, :single_dim*single_dim,:] + hidden_states_1 = hidden_states[:, single_dim*(single_dim+1):,:] + + batch_size, sequence_length, _ = ( + hidden_states_0.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states_0 = attn.group_norm(hidden_states_0.transpose(1, 2)).transpose(1, 2) + hidden_states_1 = attn.group_norm(hidden_states_1.transpose(1, 2)).transpose(1, 2) + + query_0 = attn.to_q(hidden_states_0) + query_1 = attn.to_q(hidden_states_1) + key_0 = attn.to_k(hidden_states_0) + key_1 = attn.to_k(hidden_states_1) + value_0 = attn.to_v(hidden_states_0) + value_1 = attn.to_v(hidden_states_1) + + # time-dependent adaLN + key_1 = self.ln_k_ref(key_1, temb) + value_1 = self.ln_v_ref(value_1, temb) + + if external_kv: + key_1 = torch.cat([key_1, external_kv.k], dim=1) + value_1 = torch.cat([value_1, external_kv.v], dim=1) + + inner_dim = key_0.shape[-1] + head_dim = inner_dim // attn.heads + + query_0 = query_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + query_1 = query_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_0 = key_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key_1 = key_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_0 = value_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value_1 = value_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states_0 = F.scaled_dot_product_attention( + query_0, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states_1 = F.scaled_dot_product_attention( + query_1, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + # cross-attn + _hidden_states_0 = F.scaled_dot_product_attention( + query_0, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states_0 = hidden_states_0 + ref_scale * _hidden_states_0 * 10 + + # TODO: drop this cross-attn + _hidden_states_1 = F.scaled_dot_product_attention( + query_1, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + hidden_states_1 = hidden_states_1 + ref_scale * _hidden_states_1 + + hidden_states_0 = hidden_states_0.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_1 = hidden_states_1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states_0 = hidden_states_0.to(query_0.dtype) + hidden_states_1 = hidden_states_1.to(query_1.dtype) + + + # linear proj + hidden_states_0 = attn.to_out[0](hidden_states_0) + hidden_states_1 = attn.to_out[0](hidden_states_1) + # dropout + hidden_states_0 = attn.to_out[1](hidden_states_0) + hidden_states_1 = attn.to_out[1](hidden_states_1) + + + if input_ndim == 4: + hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width) + hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if cat_dim==-2 or cat_dim==2: + hidden_states_pad = torch.zeros(batch_size, channel, 1, width) + elif cat_dim==-1 or cat_dim==3: + hidden_states_pad = torch.zeros(batch_size, channel, height, 1) + hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype) + hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim) + assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}" + else: + batch_size, sequence_length, inner_dim = hidden_states.shape + hidden_states_pad = torch.zeros(batch_size, single_dim, inner_dim) + hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype) + hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1) + assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}" + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AdditiveKV_AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__( + self, + hidden_size: int = None, + cross_attention_dim: int = None, + time_embedding_dim: int = None, + additive_scale: float = 1.0, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.additive_scale = additive_scale + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + external_kv=None, + attention_mask=None, + temb=None, + ): + assert temb is not None, "Timestep embedding is needed for a time-aware attention processor." + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + if external_kv: + key = external_kv.k + value = external_kv.v + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + external_attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + external_attn_output = external_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states + self.additive_scale * external_attn_output + + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class TA_AdditiveKV_AttnProcessor2_0(torch.nn.Module): + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__( + self, + hidden_size: int = None, + cross_attention_dim: int = None, + time_embedding_dim: int = None, + additive_scale: float = 1.0, + ): + super().__init__() + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.ln_k = AdaLayerNorm(hidden_size, time_embedding_dim) + self.ln_v = AdaLayerNorm(hidden_size, time_embedding_dim) + self.additive_scale = additive_scale + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + external_kv=None, + attention_mask=None, + temb=None, + ): + assert temb is not None, "Timestep embedding is needed for a time-aware attention processor." + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + + if external_kv: + key = external_kv.k + value = external_kv.v + + # time-dependent adaLN + key = self.ln_k(key, temb) + value = self.ln_v(value, temb) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + external_attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + external_attn_output = external_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states + self.additive_scale * external_attn_output + + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class IPAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + if isinstance(encoder_hidden_states, tuple): + # FIXME: now hard coded to single image prompt. + batch_size, _, hid_dim = encoder_hidden_states[0].shape + ip_tokens = encoder_hidden_states[1][0] + encoder_hidden_states = torch.cat([encoder_hidden_states[0], ip_tokens], dim=1) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class TA_IPAttnProcessor2_0(torch.nn.Module): + r""" + Attention processor for IP-Adapater for PyTorch 2.0. + Args: + hidden_size (`int`): + The hidden size of the attention layer. + cross_attention_dim (`int`): + The number of channels in the `encoder_hidden_states`. + scale (`float`, defaults to 1.0): + the weight scale of image prompt. + num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): + The context length of the image features. + """ + + def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4): + super().__init__() + + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + self.hidden_size = hidden_size + self.cross_attention_dim = cross_attention_dim + self.scale = scale + self.num_tokens = num_tokens + + self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) + self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim) + self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim) + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + external_kv=None, + temb=None, + ): + assert temb is not None, "Timestep embedding is needed for a time-aware attention processor." + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + if not isinstance(encoder_hidden_states, tuple): + # get encoder_hidden_states, ip_hidden_states + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states, ip_hidden_states = ( + encoder_hidden_states[:, :end_pos, :], + encoder_hidden_states[:, end_pos:, :], + ) + else: + # FIXME: now hard coded to single image prompt. + batch_size, _, hid_dim = encoder_hidden_states[0].shape + ip_hidden_states = encoder_hidden_states[1][0] + encoder_hidden_states = encoder_hidden_states[0] + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + if external_kv: + key = torch.cat([key, external_kv.k], axis=1) + value = torch.cat([value, external_kv.v], axis=1) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # for ip-adapter + ip_key = self.to_k_ip(ip_hidden_states) + ip_value = self.to_v_ip(ip_hidden_states) + + # time-dependent adaLN + ip_key = self.ln_k_ip(ip_key, temb) + ip_value = self.ln_v_ip(ip_value, temb) + + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) + + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + ip_hidden_states = ip_hidden_states.to(query.dtype) + + hidden_states = hidden_states + self.scale * ip_hidden_states + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +## for controlnet +class CNAttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __init__(self, num_tokens=4): + self.num_tokens = num_tokens + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class CNAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self, num_tokens=4): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + self.num_tokens = num_tokens + + def __call__( + self, + attn, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + temb=None, + ): + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + else: + end_pos = encoder_hidden_states.shape[1] - self.num_tokens + encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +def init_attn_proc(unet, ip_adapter_tokens=16, use_lcm=False, use_adaln=True, use_external_kv=False): + attn_procs = {} + unet_sd = unet.state_dict() + for name in unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + if cross_attention_dim is None: + if use_external_kv: + attn_procs[name] = AdditiveKV_AttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + time_embedding_dim=1280, + ) if hasattr(F, "scaled_dot_product_attention") else AdditiveKV_AttnProcessor() + else: + attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor() + else: + if use_adaln: + layer_name = name.split(".processor")[0] + if use_lcm: + weights = { + "to_k_ip.weight": unet_sd[layer_name + ".to_k.base_layer.weight"], + "to_v_ip.weight": unet_sd[layer_name + ".to_v.base_layer.weight"], + } + else: + weights = { + "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], + "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], + } + attn_procs[name] = TA_IPAttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + num_tokens=ip_adapter_tokens, + time_embedding_dim=1280, + ) if hasattr(F, "scaled_dot_product_attention") else \ + TA_IPAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + num_tokens=ip_adapter_tokens, + time_embedding_dim=1280, + ) + attn_procs[name].load_state_dict(weights, strict=False) + else: + attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor() + + return attn_procs + + +def init_aggregator_attn_proc(unet, use_adaln=False, split_attn=False): + attn_procs = {} + unet_sd = unet.state_dict() + for name in unet.attn_processors.keys(): + # get layer name and hidden dim + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + # init attn proc + if split_attn: + # layer_name = name.split(".processor")[0] + # weights = { + # "to_q_ref.weight": unet_sd[layer_name + ".to_q.weight"], + # "to_k_ref.weight": unet_sd[layer_name + ".to_k.weight"], + # "to_v_ref.weight": unet_sd[layer_name + ".to_v.weight"], + # } + attn_procs[name] = ( + sep_split_AttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=hidden_size, + time_embedding_dim=1280, + ) + if use_adaln + else split_AttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + time_embedding_dim=1280, + ) + ) + # attn_procs[name].load_state_dict(weights, strict=False) + else: + attn_procs[name] = ( + AttnProcessor2_0( + hidden_size=hidden_size, + cross_attention_dim=hidden_size, + ) + if hasattr(F, "scaled_dot_product_attention") + else AttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=hidden_size, + ) + ) + + return attn_procs diff --git a/module/ip_adapter/ip_adapter.py b/module/ip_adapter/ip_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..10f01d4f38f95fc930af2db66bef72757cc54455 --- /dev/null +++ b/module/ip_adapter/ip_adapter.py @@ -0,0 +1,236 @@ +import os +import torch +from typing import List +from collections import namedtuple, OrderedDict + +def is_torch2_available(): + return hasattr(torch.nn.functional, "scaled_dot_product_attention") + +if is_torch2_available(): + from .attention_processor import ( + AttnProcessor2_0 as AttnProcessor, + ) + from .attention_processor import ( + CNAttnProcessor2_0 as CNAttnProcessor, + ) + from .attention_processor import ( + IPAttnProcessor2_0 as IPAttnProcessor, + ) + from .attention_processor import ( + TA_IPAttnProcessor2_0 as TA_IPAttnProcessor, + ) +else: + from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor, TA_IPAttnProcessor + + +class ImageProjModel(torch.nn.Module): + """Projection Model""" + + def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4): + super().__init__() + + self.cross_attention_dim = cross_attention_dim + self.clip_extra_context_tokens = clip_extra_context_tokens + self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) + self.norm = torch.nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds): + embeds = image_embeds + clip_extra_context_tokens = self.proj(embeds).reshape( + -1, self.clip_extra_context_tokens, self.cross_attention_dim + ) + clip_extra_context_tokens = self.norm(clip_extra_context_tokens) + return clip_extra_context_tokens + + +class MLPProjModel(torch.nn.Module): + """SD model with image prompt""" + def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280): + super().__init__() + + self.proj = torch.nn.Sequential( + torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), + torch.nn.GELU(), + torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), + torch.nn.LayerNorm(cross_attention_dim) + ) + + def forward(self, image_embeds): + clip_extra_context_tokens = self.proj(image_embeds) + return clip_extra_context_tokens + + +class MultiIPAdapterImageProjection(torch.nn.Module): + def __init__(self, IPAdapterImageProjectionLayers): + super().__init__() + self.image_projection_layers = torch.nn.ModuleList(IPAdapterImageProjectionLayers) + + def forward(self, image_embeds: List[torch.FloatTensor]): + projected_image_embeds = [] + + # currently, we accept `image_embeds` as + # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim] + # 2. list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim] + if not isinstance(image_embeds, list): + image_embeds = [image_embeds.unsqueeze(1)] + + if len(image_embeds) != len(self.image_projection_layers): + raise ValueError( + f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}" + ) + + for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): + batch_size, num_images = image_embed.shape[0], image_embed.shape[1] + image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) + image_embed = image_projection_layer(image_embed) + # image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) + + projected_image_embeds.append(image_embed) + + return projected_image_embeds + + +class IPAdapter(torch.nn.Module): + """IP-Adapter""" + def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None): + super().__init__() + self.unet = unet + self.image_proj = image_proj_model + self.ip_adapter = adapter_modules + + if ckpt_path is not None: + self.load_from_checkpoint(ckpt_path) + + def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds): + ip_tokens = self.image_proj(image_embeds) + encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) + # Predict the noise residual + noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample + return noise_pred + + def load_from_checkpoint(self, ckpt_path: str): + # Calculate original checksums + orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()])) + orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()])) + + state_dict = torch.load(ckpt_path, map_location="cpu") + keys = list(state_dict.keys()) + if keys != ["image_proj", "ip_adapter"]: + state_dict = revise_state_dict(state_dict) + + # Load state dict for image_proj_model and adapter_modules + self.image_proj.load_state_dict(state_dict["image_proj"], strict=True) + self.ip_adapter.load_state_dict(state_dict["ip_adapter"], strict=True) + + # Calculate new checksums + new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()])) + new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()])) + + # Verify if the weights have changed + assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" + assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" + + +class IPAdapterPlus(torch.nn.Module): + """IP-Adapter""" + def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None): + super().__init__() + self.unet = unet + self.image_proj = image_proj_model + self.ip_adapter = adapter_modules + + if ckpt_path is not None: + self.load_from_checkpoint(ckpt_path) + + def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds): + ip_tokens = self.image_proj(image_embeds) + encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) + # Predict the noise residual + noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample + return noise_pred + + def load_from_checkpoint(self, ckpt_path: str): + # Calculate original checksums + orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()])) + orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()])) + org_unet_sum = [] + for attn_name, attn_proc in self.unet.attn_processors.items(): + if isinstance(attn_proc, (TA_IPAttnProcessor, IPAttnProcessor)): + org_unet_sum.append(torch.sum(torch.stack([torch.sum(p) for p in attn_proc.parameters()]))) + org_unet_sum = torch.sum(torch.stack(org_unet_sum)) + + state_dict = torch.load(ckpt_path, map_location="cpu") + keys = list(state_dict.keys()) + if keys != ["image_proj", "ip_adapter"]: + state_dict = revise_state_dict(state_dict) + + # Check if 'latents' exists in both the saved state_dict and the current model's state_dict + strict_load_image_proj_model = True + if "latents" in state_dict["image_proj"] and "latents" in self.image_proj.state_dict(): + # Check if the shapes are mismatched + if state_dict["image_proj"]["latents"].shape != self.image_proj.state_dict()["latents"].shape: + print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.") + print("Removing 'latents' from checkpoint and loading the rest of the weights.") + del state_dict["image_proj"]["latents"] + strict_load_image_proj_model = False + + # Load state dict for image_proj_model and adapter_modules + self.image_proj.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model) + missing_key, unexpected_key = self.ip_adapter.load_state_dict(state_dict["ip_adapter"], strict=False) + if len(missing_key) > 0: + for ms in missing_key: + if "ln" not in ms: + raise ValueError(f"Missing key in adapter_modules: {len(missing_key)}") + if len(unexpected_key) > 0: + raise ValueError(f"Unexpected key in adapter_modules: {len(unexpected_key)}") + + # Calculate new checksums + new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()])) + new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()])) + + # Verify if the weights loaded to unet + unet_sum = [] + for attn_name, attn_proc in self.unet.attn_processors.items(): + if isinstance(attn_proc, (TA_IPAttnProcessor, IPAttnProcessor)): + unet_sum.append(torch.sum(torch.stack([torch.sum(p) for p in attn_proc.parameters()]))) + unet_sum = torch.sum(torch.stack(unet_sum)) + + assert org_unet_sum != unet_sum, "Weights of adapter_modules in unet did not change!" + assert (unet_sum - new_adapter_sum < 1e-4), "Weights of adapter_modules did not load to unet!" + + # Verify if the weights have changed + assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!" + assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_mod`ules did not change!" + + +class IPAdapterXL(IPAdapter): + """SDXL""" + + def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds): + ip_tokens = self.image_proj(image_embeds) + encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) + # Predict the noise residual + noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample + return noise_pred + + +class IPAdapterPlusXL(IPAdapterPlus): + """IP-Adapter with fine-grained features""" + + def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds): + ip_tokens = self.image_proj(image_embeds) + encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1) + # Predict the noise residual + noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample + return noise_pred + + +class IPAdapterFull(IPAdapterPlus): + """IP-Adapter with full features""" + + def init_proj(self): + image_proj_model = MLPProjModel( + cross_attention_dim=self.pipe.unet.config.cross_attention_dim, + clip_embeddings_dim=self.image_encoder.config.hidden_size, + ).to(self.device, dtype=torch.float16) + return image_proj_model diff --git a/module/ip_adapter/resampler.py b/module/ip_adapter/resampler.py new file mode 100644 index 0000000000000000000000000000000000000000..72295f90b1e9fdfe717de9af317fc55e5d466a09 --- /dev/null +++ b/module/ip_adapter/resampler.py @@ -0,0 +1,158 @@ +# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py +# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py + +import math + +import torch +import torch.nn as nn +from einops import rearrange +from einops.layers.torch import Rearrange + + +# FFN +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + + +class Resampler(nn.Module): + def __init__( + self, + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=64, + embedding_dim=768, + output_dim=1024, + ff_mult=4, + max_seq_len: int = 257, # CLIP tokens + CLS token + apply_pos_emb: bool = False, + num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence + ): + super().__init__() + self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None + + self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) + + self.proj_in = nn.Linear(embedding_dim, dim) + + self.proj_out = nn.Linear(dim, output_dim) + self.norm_out = nn.LayerNorm(output_dim) + + self.to_latents_from_mean_pooled_seq = ( + nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, dim * num_latents_mean_pooled), + Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled), + ) + if num_latents_mean_pooled > 0 + else None + ) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, x): + if self.pos_emb is not None: + n, device = x.shape[1], x.device + pos_emb = self.pos_emb(torch.arange(n, device=device)) + x = x + pos_emb + + latents = self.latents.repeat(x.size(0), 1, 1) + + x = self.proj_in(x) + + if self.to_latents_from_mean_pooled_seq: + meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool)) + meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) + latents = torch.cat((meanpooled_latents, latents), dim=-2) + + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + +def masked_mean(t, *, dim, mask=None): + if mask is None: + return t.mean(dim=dim) + + denom = mask.sum(dim=dim, keepdim=True) + mask = rearrange(mask, "b n -> b n 1") + masked_t = t.masked_fill(~mask, 0.0) + + return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) diff --git a/module/ip_adapter/utils.py b/module/ip_adapter/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..64c45cd85ac8aa2a03ad5fa66429ddd46824cd69 --- /dev/null +++ b/module/ip_adapter/utils.py @@ -0,0 +1,248 @@ +import torch +from collections import namedtuple, OrderedDict +from safetensors import safe_open +from .attention_processor import init_attn_proc +from .ip_adapter import MultiIPAdapterImageProjection +from .resampler import Resampler +from transformers import ( + AutoModel, AutoImageProcessor, + CLIPVisionModelWithProjection, CLIPImageProcessor) + + +def init_adapter_in_unet( + unet, + image_proj_model=None, + pretrained_model_path_or_dict=None, + adapter_tokens=64, + embedding_dim=None, + use_lcm=False, + use_adaln=True, + ): + device = unet.device + dtype = unet.dtype + if image_proj_model is None: + assert embedding_dim is not None, "embedding_dim must be provided if image_proj_model is None." + image_proj_model = Resampler( + embedding_dim=embedding_dim, + output_dim=unet.config.cross_attention_dim, + num_queries=adapter_tokens, + ) + if pretrained_model_path_or_dict is not None: + if not isinstance(pretrained_model_path_or_dict, dict): + if pretrained_model_path_or_dict.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(pretrained_model_path_or_dict, framework="pt", device=unet.device) as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(pretrained_model_path_or_dict, map_location=unet.device) + else: + state_dict = pretrained_model_path_or_dict + keys = list(state_dict.keys()) + if "image_proj" not in keys and "ip_adapter" not in keys: + state_dict = revise_state_dict(state_dict) + + # Creat IP cross-attention in unet. + attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln) + unet.set_attn_processor(attn_procs) + + # Load pretrinaed model if needed. + if pretrained_model_path_or_dict is not None: + if "ip_adapter" in state_dict.keys(): + adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) + missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False) + for mk in missing: + if "ln" not in mk: + raise ValueError(f"Missing keys in adapter_modules: {missing}") + if "image_proj" in state_dict.keys(): + image_proj_model.load_state_dict(state_dict["image_proj"]) + + # Load image projectors into iterable ModuleList. + image_projection_layers = [] + image_projection_layers.append(image_proj_model) + unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + + # Adjust unet config to handle addtional ip hidden states. + unet.config.encoder_hid_dim_type = "ip_image_proj" + unet.to(dtype=dtype, device=device) + + +def load_adapter_to_pipe( + pipe, + pretrained_model_path_or_dict, + image_encoder_or_path=None, + feature_extractor_or_path=None, + use_clip_encoder=False, + adapter_tokens=64, + use_lcm=False, + use_adaln=True, + ): + + if not isinstance(pretrained_model_path_or_dict, dict): + if pretrained_model_path_or_dict.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device) + else: + state_dict = pretrained_model_path_or_dict + keys = list(state_dict.keys()) + if "image_proj" not in keys and "ip_adapter" not in keys: + state_dict = revise_state_dict(state_dict) + + # load CLIP image encoder here if it has not been registered to the pipeline yet + if image_encoder_or_path is not None: + if isinstance(image_encoder_or_path, str): + feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path + + image_encoder_or_path = ( + CLIPVisionModelWithProjection.from_pretrained( + image_encoder_or_path + ) if use_clip_encoder else + AutoModel.from_pretrained(image_encoder_or_path) + ) + + if feature_extractor_or_path is not None: + if isinstance(feature_extractor_or_path, str): + feature_extractor_or_path = ( + CLIPImageProcessor() if use_clip_encoder else + AutoImageProcessor.from_pretrained(feature_extractor_or_path) + ) + + # create image encoder if it has not been registered to the pipeline yet + if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None: + image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype) + pipe.register_modules(image_encoder=image_encoder) + else: + image_encoder = pipe.image_encoder + + # create feature extractor if it has not been registered to the pipeline yet + if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None: + feature_extractor = feature_extractor_or_path + pipe.register_modules(feature_extractor=feature_extractor) + else: + feature_extractor = pipe.feature_extractor + + # load adapter into unet + unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet + attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln) + unet.set_attn_processor(attn_procs) + image_proj_model = Resampler( + embedding_dim=image_encoder.config.hidden_size, + output_dim=unet.config.cross_attention_dim, + num_queries=adapter_tokens, + ) + + # Load pretrinaed model if needed. + if "ip_adapter" in state_dict.keys(): + adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) + missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False) + for mk in missing: + if "ln" not in mk: + raise ValueError(f"Missing keys in adapter_modules: {missing}") + if "image_proj" in state_dict.keys(): + image_proj_model.load_state_dict(state_dict["image_proj"]) + + # convert IP-Adapter Image Projection layers to diffusers + image_projection_layers = [] + image_projection_layers.append(image_proj_model) + unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + + # Adjust unet config to handle addtional ip hidden states. + unet.config.encoder_hid_dim_type = "ip_image_proj" + unet.to(dtype=pipe.dtype, device=pipe.device) + + +def revise_state_dict(old_state_dict_or_path, map_location="cpu"): + new_state_dict = OrderedDict() + new_state_dict["image_proj"] = OrderedDict() + new_state_dict["ip_adapter"] = OrderedDict() + if isinstance(old_state_dict_or_path, str): + old_state_dict = torch.load(old_state_dict_or_path, map_location=map_location) + else: + old_state_dict = old_state_dict_or_path + for name, weight in old_state_dict.items(): + if name.startswith("image_proj_model."): + new_state_dict["image_proj"][name[len("image_proj_model."):]] = weight + elif name.startswith("adapter_modules."): + new_state_dict["ip_adapter"][name[len("adapter_modules."):]] = weight + return new_state_dict + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image +def encode_image(image_encoder, feature_extractor, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + return image_enc_hidden_states + else: + if isinstance(image_encoder, CLIPVisionModelWithProjection): + # CLIP image encoder. + image_embeds = image_encoder(image).image_embeds + else: + # DINO image encoder. + image_embeds = image_encoder(image).last_hidden_state + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + return image_embeds + + +def prepare_training_image_embeds( + image_encoder, feature_extractor, + ip_adapter_image, ip_adapter_image_embeds, + device, drop_rate, output_hidden_state, idx_to_replace=None +): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + # if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers): + # raise ValueError( + # f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + # ) + + image_embeds = [] + for single_ip_adapter_image in ip_adapter_image: + if idx_to_replace is None: + idx_to_replace = torch.rand(len(single_ip_adapter_image)) < drop_rate + zero_ip_adapter_image = torch.zeros_like(single_ip_adapter_image) + single_ip_adapter_image[idx_to_replace] = zero_ip_adapter_image[idx_to_replace] + single_image_embeds = encode_image( + image_encoder, feature_extractor, single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds], dim=1) # FIXME + + image_embeds.append(single_image_embeds) + else: + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + else: + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + image_embeds.append(single_image_embeds) + + return image_embeds \ No newline at end of file diff --git a/module/min_sdxl.py b/module/min_sdxl.py new file mode 100644 index 0000000000000000000000000000000000000000..fd94b03dc8f7a1484888085b440a66f99a4b0275 --- /dev/null +++ b/module/min_sdxl.py @@ -0,0 +1,915 @@ +# Modified from minSDXL by Simo Ryu: +# https://github.com/cloneofsimo/minSDXL , +# which is in turn modified from the original code of: +# https://github.com/huggingface/diffusers +# So has APACHE 2.0 license + +from typing import Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import inspect + +from collections import namedtuple + +from torch.fft import fftn, fftshift, ifftn, ifftshift + +from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0 + +# Implementation of FreeU for minsdxl + +def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor": + """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497). + + This version of the method comes from here: + https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706 + """ + x = x_in + B, C, H, W = x.shape + + # Non-power of 2 images must be float32 + if (W & (W - 1)) != 0 or (H & (H - 1)) != 0: + x = x.to(dtype=torch.float32) + + # FFT + x_freq = fftn(x, dim=(-2, -1)) + x_freq = fftshift(x_freq, dim=(-2, -1)) + + B, C, H, W = x_freq.shape + mask = torch.ones((B, C, H, W), device=x.device) + + crow, ccol = H // 2, W // 2 + mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale + x_freq = x_freq * mask + + # IFFT + x_freq = ifftshift(x_freq, dim=(-2, -1)) + x_filtered = ifftn(x_freq, dim=(-2, -1)).real + + return x_filtered.to(dtype=x_in.dtype) + + +def apply_freeu( + resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs): + """Applies the FreeU mechanism as introduced in https: + //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU. + + Args: + resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied. + hidden_states (`torch.Tensor`): Inputs to the underlying block. + res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block. + s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features. + s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + if resolution_idx == 0: + num_half_channels = hidden_states.shape[1] // 2 + hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"] + res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"]) + if resolution_idx == 1: + num_half_channels = hidden_states.shape[1] // 2 + hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"] + res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"]) + + return hidden_states, res_hidden_states + +# Diffusers-style LoRA to keep everything in the min_sdxl.py file + +class LoRALinearLayer(nn.Module): + r""" + A linear layer that is used with LoRA. + + Parameters: + in_features (`int`): + Number of input features. + out_features (`int`): + Number of output features. + rank (`int`, `optional`, defaults to 4): + The rank of the LoRA layer. + network_alpha (`float`, `optional`, defaults to `None`): + The value of the network alpha used for stable learning and preventing underflow. This value has the same + meaning as the `--network_alpha` option in the kohya-ss trainer script. See + https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + device (`torch.device`, `optional`, defaults to `None`): + The device to use for the layer's weights. + dtype (`torch.dtype`, `optional`, defaults to `None`): + The dtype to use for the layer's weights. + """ + + def __init__( + self, + in_features: int, + out_features: int, + rank: int = 4, + network_alpha: Optional[float] = None, + device: Optional[Union[torch.device, str]] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + + self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype) + self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning + self.network_alpha = network_alpha + self.rank = rank + self.out_features = out_features + self.in_features = in_features + + nn.init.normal_(self.down.weight, std=1 / rank) + nn.init.zeros_(self.up.weight) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + orig_dtype = hidden_states.dtype + dtype = self.down.weight.dtype + + down_hidden_states = self.down(hidden_states.to(dtype)) + up_hidden_states = self.up(down_hidden_states) + + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + + return up_hidden_states.to(orig_dtype) + +class LoRACompatibleLinear(nn.Linear): + """ + A Linear layer that can be used with LoRA. + """ + + def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs): + super().__init__(*args, **kwargs) + self.lora_layer = lora_layer + + def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]): + self.lora_layer = lora_layer + + def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False): + if self.lora_layer is None: + return + + dtype, device = self.weight.data.dtype, self.weight.data.device + + w_orig = self.weight.data.float() + w_up = self.lora_layer.up.weight.data.float() + w_down = self.lora_layer.down.weight.data.float() + + if self.lora_layer.network_alpha is not None: + w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank + + fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + + if safe_fusing and torch.isnan(fused_weight).any().item(): + raise ValueError( + "This LoRA weight seems to be broken. " + f"Encountered NaN values when trying to fuse LoRA weights for {self}." + "LoRA weights will not be fused." + ) + + self.weight.data = fused_weight.to(device=device, dtype=dtype) + + # we can drop the lora layer now + self.lora_layer = None + + # offload the up and down matrices to CPU to not blow the memory + self.w_up = w_up.cpu() + self.w_down = w_down.cpu() + self._lora_scale = lora_scale + + def _unfuse_lora(self): + if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None): + return + + fused_weight = self.weight.data + dtype, device = fused_weight.dtype, fused_weight.device + + w_up = self.w_up.to(device=device).float() + w_down = self.w_down.to(device).float() + + unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]) + self.weight.data = unfused_weight.to(device=device, dtype=dtype) + + self.w_up = None + self.w_down = None + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + if self.lora_layer is None: + out = super().forward(hidden_states) + return out + else: + out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states)) + return out + +class Timesteps(nn.Module): + def __init__(self, num_channels: int = 320): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + half_dim = self.num_channels // 2 + exponent = -math.log(10000) * torch.arange( + half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - 0.0) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + sin_emb = torch.sin(emb) + cos_emb = torch.cos(emb) + emb = torch.cat([cos_emb, sin_emb], dim=-1) + + return emb + + +class TimestepEmbedding(nn.Module): + def __init__(self, in_features, out_features): + super(TimestepEmbedding, self).__init__() + self.linear_1 = nn.Linear(in_features, out_features, bias=True) + self.act = nn.SiLU() + self.linear_2 = nn.Linear(out_features, out_features, bias=True) + + def forward(self, sample): + sample = self.linear_1(sample) + sample = self.act(sample) + sample = self.linear_2(sample) + + return sample + + +class ResnetBlock2D(nn.Module): + def __init__(self, in_channels, out_channels, conv_shortcut=True): + super(ResnetBlock2D, self).__init__() + self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-05, affine=True) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.time_emb_proj = nn.Linear(1280, out_channels, bias=True) + self.norm2 = nn.GroupNorm(32, out_channels, eps=1e-05, affine=True) + self.dropout = nn.Dropout(p=0.0, inplace=False) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.nonlinearity = nn.SiLU() + self.conv_shortcut = None + if conv_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1 + ) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + temb = self.nonlinearity(temb) + temb = self.time_emb_proj(temb)[:, :, None, None] + hidden_states = hidden_states + temb + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class Attention(nn.Module): + def __init__( + self, inner_dim, cross_attention_dim=None, num_heads=None, dropout=0.0, processor=None, scale_qk=True + ): + super(Attention, self).__init__() + if num_heads is None: + self.head_dim = 64 + self.num_heads = inner_dim // self.head_dim + else: + self.num_heads = num_heads + self.head_dim = inner_dim // num_heads + + self.scale = self.head_dim**-0.5 + if cross_attention_dim is None: + cross_attention_dim = inner_dim + self.to_q = LoRACompatibleLinear(inner_dim, inner_dim, bias=False) + self.to_k = LoRACompatibleLinear(cross_attention_dim, inner_dim, bias=False) + self.to_v = LoRACompatibleLinear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.ModuleList( + [LoRACompatibleLinear(inner_dim, inner_dim), nn.Dropout(dropout, inplace=False)] + ) + + self.scale_qk = scale_qk + if processor is None: + processor = ( + AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() + ) + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters] + if len(unused_kwargs) > 0: + print( + f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + + def orig_forward(self, hidden_states, encoder_hidden_states=None): + q = self.to_q(hidden_states) + k = ( + self.to_k(encoder_hidden_states) + if encoder_hidden_states is not None + else self.to_k(hidden_states) + ) + v = ( + self.to_v(encoder_hidden_states) + if encoder_hidden_states is not None + else self.to_v(hidden_states) + ) + b, t, c = q.size() + + q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim).transpose(1, 2) + + # scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale + # attn_weights = torch.softmax(scores, dim=-1) + # attn_output = torch.matmul(attn_weights, v) + + attn_output = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale, + ) + + attn_output = attn_output.transpose(1, 2).contiguous().view(b, t, c) + + for layer in self.to_out: + attn_output = layer(attn_output) + + return attn_output + + def set_processor(self, processor) -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + print(f"You are removing possibly trained weights of {self.processor} with {processor}") + self._modules.pop("processor") + + self.processor = processor + + def get_processor(self, return_deprecated_lora: bool = False): + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not possible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict()) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict()) + lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict()) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + +class GEGLU(nn.Module): + def __init__(self, in_features, out_features): + super(GEGLU, self).__init__() + self.proj = nn.Linear(in_features, out_features * 2, bias=True) + + def forward(self, x): + x_proj = self.proj(x) + x1, x2 = x_proj.chunk(2, dim=-1) + return x1 * torch.nn.functional.gelu(x2) + + +class FeedForward(nn.Module): + def __init__(self, in_features, out_features): + super(FeedForward, self).__init__() + + self.net = nn.ModuleList( + [ + GEGLU(in_features, out_features * 4), + nn.Dropout(p=0.0, inplace=False), + nn.Linear(out_features * 4, out_features, bias=True), + ] + ) + + def forward(self, x): + for layer in self.net: + x = layer(x) + return x + + +class BasicTransformerBlock(nn.Module): + def __init__(self, hidden_size): + super(BasicTransformerBlock, self).__init__() + self.norm1 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) + self.attn1 = Attention(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) + self.attn2 = Attention(hidden_size, 2048) + self.norm3 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True) + self.ff = FeedForward(hidden_size, hidden_size) + + def forward(self, x, encoder_hidden_states=None): + residual = x + + x = self.norm1(x) + x = self.attn1(x) + x = x + residual + + residual = x + + x = self.norm2(x) + if encoder_hidden_states is not None: + x = self.attn2(x, encoder_hidden_states) + else: + x = self.attn2(x) + x = x + residual + + residual = x + + x = self.norm3(x) + x = self.ff(x) + x = x + residual + return x + + +class Transformer2DModel(nn.Module): + def __init__(self, in_channels, out_channels, n_layers): + super(Transformer2DModel, self).__init__() + self.norm = nn.GroupNorm(32, in_channels, eps=1e-06, affine=True) + self.proj_in = nn.Linear(in_channels, out_channels, bias=True) + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(out_channels) for _ in range(n_layers)] + ) + self.proj_out = nn.Linear(out_channels, out_channels, bias=True) + + def forward(self, hidden_states, encoder_hidden_states=None): + batch, _, height, width = hidden_states.shape + res = hidden_states + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * width, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states) + + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, width, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + return hidden_states + res + + +class Downsample2D(nn.Module): + def __init__(self, in_channels, out_channels): + super(Downsample2D, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, padding=1 + ) + + def forward(self, x): + return self.conv(x) + + +class Upsample2D(nn.Module): + def __init__(self, in_channels, out_channels): + super(Upsample2D, self).__init__() + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + return self.conv(x) + + +class DownBlock2D(nn.Module): + def __init__(self, in_channels, out_channels): + super(DownBlock2D, self).__init__() + self.resnets = nn.ModuleList( + [ + ResnetBlock2D(in_channels, out_channels, conv_shortcut=False), + ResnetBlock2D(out_channels, out_channels, conv_shortcut=False), + ] + ) + self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)]) + + def forward(self, hidden_states, temb): + output_states = [] + for module in self.resnets: + hidden_states = module(hidden_states, temb) + output_states.append(hidden_states) + + hidden_states = self.downsamplers[0](hidden_states) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__(self, in_channels, out_channels, n_layers, has_downsamplers=True): + super(CrossAttnDownBlock2D, self).__init__() + self.attentions = nn.ModuleList( + [ + Transformer2DModel(out_channels, out_channels, n_layers), + Transformer2DModel(out_channels, out_channels, n_layers), + ] + ) + self.resnets = nn.ModuleList( + [ + ResnetBlock2D(in_channels, out_channels), + ResnetBlock2D(out_channels, out_channels, conv_shortcut=False), + ] + ) + self.downsamplers = None + if has_downsamplers: + self.downsamplers = nn.ModuleList( + [Downsample2D(out_channels, out_channels)] + ) + + def forward(self, hidden_states, temb, encoder_hidden_states): + output_states = [] + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + output_states.append(hidden_states) + + if self.downsamplers is not None: + hidden_states = self.downsamplers[0](hidden_states) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__(self, in_channels, out_channels, prev_output_channel, n_layers): + super(CrossAttnUpBlock2D, self).__init__() + self.attentions = nn.ModuleList( + [ + Transformer2DModel(out_channels, out_channels, n_layers), + Transformer2DModel(out_channels, out_channels, n_layers), + Transformer2DModel(out_channels, out_channels, n_layers), + ] + ) + self.resnets = nn.ModuleList( + [ + ResnetBlock2D(prev_output_channel + out_channels, out_channels), + ResnetBlock2D(2 * out_channels, out_channels), + ResnetBlock2D(out_channels + in_channels, out_channels), + ] + ) + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)]) + + def forward( + self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__(self, in_channels, out_channels, prev_output_channel): + super(UpBlock2D, self).__init__() + self.resnets = nn.ModuleList( + [ + ResnetBlock2D(out_channels + prev_output_channel, out_channels), + ResnetBlock2D(out_channels * 2, out_channels), + ResnetBlock2D(out_channels + in_channels, out_channels), + ] + ) + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + and getattr(self, "resolution_idx", None) + ) + + for resnet in self.resnets: + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__(self, in_features): + super(UNetMidBlock2DCrossAttn, self).__init__() + self.attentions = nn.ModuleList( + [Transformer2DModel(in_features, in_features, n_layers=10)] + ) + self.resnets = nn.ModuleList( + [ + ResnetBlock2D(in_features, in_features, conv_shortcut=False), + ResnetBlock2D(in_features, in_features, conv_shortcut=False), + ] + ) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNet2DConditionModel(nn.Module): + def __init__(self): + super(UNet2DConditionModel, self).__init__() + + # This is needed to imitate huggingface config behavior + # has nothing to do with the model itself + # remove this if you don't use diffuser's pipeline + self.config = namedtuple( + "config", "in_channels addition_time_embed_dim sample_size" + ) + self.config.in_channels = 4 + self.config.addition_time_embed_dim = 256 + self.config.sample_size = 128 + + self.conv_in = nn.Conv2d(4, 320, kernel_size=3, stride=1, padding=1) + self.time_proj = Timesteps() + self.time_embedding = TimestepEmbedding(in_features=320, out_features=1280) + self.add_time_proj = Timesteps(256) + self.add_embedding = TimestepEmbedding(in_features=2816, out_features=1280) + self.down_blocks = nn.ModuleList( + [ + DownBlock2D(in_channels=320, out_channels=320), + CrossAttnDownBlock2D(in_channels=320, out_channels=640, n_layers=2), + CrossAttnDownBlock2D( + in_channels=640, + out_channels=1280, + n_layers=10, + has_downsamplers=False, + ), + ] + ) + self.up_blocks = nn.ModuleList( + [ + CrossAttnUpBlock2D( + in_channels=640, + out_channels=1280, + prev_output_channel=1280, + n_layers=10, + ), + CrossAttnUpBlock2D( + in_channels=320, + out_channels=640, + prev_output_channel=1280, + n_layers=2, + ), + UpBlock2D(in_channels=320, out_channels=320, prev_output_channel=640), + ] + ) + self.mid_block = UNetMidBlock2DCrossAttn(1280) + self.conv_norm_out = nn.GroupNorm(32, 320, eps=1e-05, affine=True) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv2d(320, 4, kernel_size=3, stride=1, padding=1) + + def forward( + self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, **kwargs + ): + # Implement the forward pass through the model + timesteps = timesteps.expand(sample.shape[0]) + t_emb = self.time_proj(timesteps).to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb) + + text_embeds = added_cond_kwargs.get("text_embeds") + time_ids = added_cond_kwargs.get("time_ids") + + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb + + sample = self.conv_in(sample) + + # 3. down + s0 = sample + sample, [s1, s2, s3] = self.down_blocks[0]( + sample, + temb=emb, + ) + + sample, [s4, s5, s6] = self.down_blocks[1]( + sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + + sample, [s7, s8] = self.down_blocks[2]( + sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + ) + + # 4. mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states + ) + + # 5. up + sample = self.up_blocks[0]( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=[s6, s7, s8], + encoder_hidden_states=encoder_hidden_states, + ) + + sample = self.up_blocks[1]( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=[s3, s4, s5], + encoder_hidden_states=encoder_hidden_states, + ) + + sample = self.up_blocks[2]( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=[s0, s1, s2], + ) + + # 6. post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return [sample] \ No newline at end of file diff --git a/module/unet/unet_2d_ZeroSFT.py b/module/unet/unet_2d_ZeroSFT.py new file mode 100644 index 0000000000000000000000000000000000000000..c91286fcf0d119c5e04c202b75a27cf5e8fdbd59 --- /dev/null +++ b/module/unet/unet_2d_ZeroSFT.py @@ -0,0 +1,1397 @@ +# Copy from diffusers.models.unets.unet_2d_condition.py + +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from .unet_2d_ZeroSFT_blocks import ( + get_down_block, + get_mid_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ZeroConv(nn.Module): + def __init__(self, label_nc, norm_nc, mask=False): + super().__init__() + self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0)) + self.mask = mask + + def forward(self, c, h, h_ori=None): + # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32): + if not self.mask: + h = h + self.zero_conv(c) + else: + h = h + self.zero_conv(c) * torch.zeros_like(h) + if h_ori is not None: + h = torch.cat([h_ori, h], dim=1) + return h + + +class ZeroSFT(nn.Module): + def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False): + super().__init__() + + # param_free_norm_type = str(parsed.group(1)) + ks = 3 + pw = ks // 2 + + self.mask = mask + self.norm = norm + self.pre_concat = bool(concat_channels != 0) + if self.norm: + self.param_free_norm = torch.nn.GroupNorm(num_groups=32, num_channels=norm_nc + concat_channels) + else: + self.param_free_norm = nn.Identity() + + nhidden = 128 + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), + nn.SiLU() + ) + self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)) + self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)) + + self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0)) + + def forward(self, down_block_res_samples, h_ori=None, control_scale=1.0, mask=False): + mask = mask or self.mask + assert mask is False + if self.pre_concat: + assert h_ori is not None + + c,h = down_block_res_samples + if h_ori is not None: + h_raw = torch.cat([h_ori, h], dim=1) + else: + h_raw = h + + if self.mask: + h = h + self.zero_conv(c) * torch.zeros_like(h) + else: + h = h + self.zero_conv(c) + if h_ori is not None and self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + actv = self.mlp_shared(c) + gamma = self.zero_mul(actv) + beta = self.zero_add(actv) + if self.mask: + gamma = gamma * torch.zeros_like(gamma) + beta = beta * torch.zeros_like(beta) + # h = h + self.param_free_norm(h) * gamma + beta + h = self.param_free_norm(h) * (gamma + 1) + beta + if h_ori is not None and not self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + return h * control_scale + h_raw * (1 - control_scale) + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNet2DZeroSFTModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"] + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + self._check_config( + down_block_types=down_block_types, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim, timestep_input_dim = self._set_time_proj( + time_embedding_type, + block_out_channels=block_out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embedding_dim=time_embedding_dim, + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + self._set_encoder_hid_proj( + encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) + + # class embedding + self._set_class_embedding( + class_embed_type, + act_fn=act_fn, + num_class_embeds=num_class_embeds, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + timestep_input_dim=timestep_input_dim, + ) + + self._set_add_embedding( + addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=addition_time_embed_dim, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + temb_channels=blocks_time_embed_dim, + in_channels=block_out_channels[-1], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + mid_block_only_cross_attention=mid_block_only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[-1], + dropout=dropout, + ) + self.mid_zero_SFT = ZeroSFT(block_out_channels[-1],block_out_channels[-1],0) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + + def _check_config( + self, + down_block_types: Tuple[str], + up_block_types: Tuple[str], + only_cross_attention: Union[bool, Tuple[bool]], + block_out_channels: Tuple[int], + layers_per_block: Union[int, Tuple[int]], + cross_attention_dim: Union[int, Tuple[int]], + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], + reverse_transformer_layers_per_block: bool, + attention_head_dim: int, + num_attention_heads: Optional[Union[int, Tuple[int]]], + ): + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + def _set_time_proj( + self, + time_embedding_type: str, + block_out_channels: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + def _set_encoder_hid_proj( + self, + encoder_hid_dim_type: Optional[str], + cross_attention_dim: Union[int, Tuple[int]], + encoder_hid_dim: Optional[int], + ): + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + def _set_class_embedding( + self, + class_embed_type: Optional[str], + act_fn: str, + num_class_embeds: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + timestep_input_dim: int, + ): + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + def _set_add_embedding( + self, + addition_embed_type: str, + addition_embed_type_num_heads: int, + addition_time_embed_dim: Optional[int], + flip_sin_to_cos: bool, + freq_shift: float, + cross_attention_dim: Optional[int], + encoder_hid_dim: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + ): + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. + """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) + are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is ๐Ÿงช experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is ๐Ÿงช experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def unload_lora(self): + """Unloads LoRA weights.""" + deprecate( + "unload_lora", + "0.28.0", + "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", + ) + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + def get_time_embed( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] + ) -> Optional[torch.Tensor]: + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + return t_emb + + def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + class_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + return class_emb + + def get_aug_embed( + self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> Optional[torch.Tensor]: + aug_emb = None + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb = self.add_embedding(image_embs, hint) + return aug_emb + + def process_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> torch.Tensor: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + return encoder_hidden_states + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, + otherwise a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + aug_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated + # to the internal blocks and will raise deprecation warnings. this will be confusing for our users. + if cross_attention_kwargs is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + lora_scale = cross_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_additional_residual, down_block_res_sample in zip( + down_block_additional_residuals, down_block_res_samples + ): + down_block_res_sample_tuple = (down_block_additional_residual, down_block_res_sample) + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample_tuple,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = self.mid_zero_SFT((mid_block_additional_residual, sample),) + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/module/unet/unet_2d_ZeroSFT_blocks.py b/module/unet/unet_2d_ZeroSFT_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..66b17c7b0a0f0b0dab77d05d0a10a1d564bae763 --- /dev/null +++ b/module/unet/unet_2d_ZeroSFT_blocks.py @@ -0,0 +1,3862 @@ +# Copy from diffusers.models.unet.unet_2d_blocks.py + +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import deprecate, is_torch_version, logging +from diffusers.utils.torch_utils import apply_freeu +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from diffusers.models.normalization import AdaGroupNorm +from diffusers.models.resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + KDownsample2D, + KUpsample2D, + ResnetBlock2D, + ResnetBlockCondNorm2D, + Upsample2D, +) +from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.transformers.transformer_2d import Transformer2DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_mid_block( + mid_block_type: str, + temb_channels: int, + in_channels: int, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: int, + output_scale_factor: float = 1.0, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + mid_block_only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = 1, + dropout: float = 0.0, +): + if mid_block_type == "UNetMidBlock2DCrossAttn": + return UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + return UNetMidBlock2DSimpleCrossAttn( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + return UNetMidBlock2D( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + num_layers=0, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + return None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warning( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +class ZeroConv(nn.Module): + def __init__(self, label_nc, norm_nc, mask=False): + super().__init__() + self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0)) + self.mask = mask + + def forward(self, c, h, h_ori=None): + # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32): + if not self.mask: + h = h + self.zero_conv(c) + else: + h = h + self.zero_conv(c) * torch.zeros_like(h) + if h_ori is not None: + h = torch.cat([h_ori, h], dim=1) + return h + + +class ZeroSFT(nn.Module): + def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False): + super().__init__() + + # param_free_norm_type = str(parsed.group(1)) + ks = 3 + pw = ks // 2 + + self.mask = mask + self.norm = norm + self.pre_concat = bool(concat_channels != 0) + if self.norm: + self.param_free_norm = torch.nn.GroupNorm(num_groups=32, num_channels=norm_nc + concat_channels) + else: + self.param_free_norm = nn.Identity() + + nhidden = 128 + + self.mlp_shared = nn.Sequential( + nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), + nn.SiLU() + ) + self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)) + self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw)) + + self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0)) + + def forward(self, down_block_res_samples, h_ori=None, control_scale=1.0, mask=False): + mask = mask or self.mask + assert mask is False + if self.pre_concat: + assert h_ori is not None + + c,h = down_block_res_samples + if h_ori is not None: + h_raw = torch.cat([h_ori, h], dim=1) + else: + h_raw = h + + if self.mask: + h = h + self.zero_conv(c) * torch.zeros_like(h) + else: + h = h + self.zero_conv(c) + if h_ori is not None and self.pre_concat: + h_ori_c = h_ori.shape[1] + h_c = h.shape[1] + h = torch.cat([h_ori, h], dim=1) + actv = self.mlp_shared(c) + gamma = self.zero_mul(actv) + beta = self.zero_add(actv) + if self.mask: + gamma = gamma * torch.zeros_like(gamma) + beta = beta * torch.zeros_like(beta) + # gamma_ori, gamma_res = torch.split(gamma, [h_ori_c, h_c], dim=1) + # beta_ori, beta_res = torch.split(beta, [h_ori_c, h_c], dim=1) + # print(gamma_ori.mean(), gamma_res.mean(), beta_ori.mean(), beta_res.mean()) + # h = h + self.param_free_norm(h) * gamma + beta + h = self.param_free_norm(h) * (gamma + 1) + beta + # sample_ori, sample_res = torch.split(h, [h_ori_c, h_c], dim=1) + # print(sample_ori.mean(), sample_res.mean()) + if h_ori is not None and not self.pre_concat: + h = torch.cat([h_ori, h], dim=1) + return h * control_scale + h_raw * (1 - control_scale) + + +class AutoencoderTinyBlock(nn.Module): + """ + Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU + blocks. + + Args: + in_channels (`int`): The number of input channels. + out_channels (`int`): The number of output channels. + act_fn (`str`): + ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`. + + Returns: + `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to + `out_channels`. + """ + + def __init__(self, in_channels: int, out_channels: int, act_fn: str): + super().__init__() + act_fn = get_activation(act_fn) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + if in_channels != out_channels + else nn.Identity() + ) + self.fuse = nn.ReLU() + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + return self.fuse(self.conv(x) + self.skip(x)) + + +class UNetMidBlock2D(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_groups_out: Optional[int] = None, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + resnet_groups_out = resnet_groups_out or resnet_groups + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + groups_out=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups_out, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups_out, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + timestep=temb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + timestep=temb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + downsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + self.downsample_type = downsample_type + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if downsample_type == "conv": + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + elif downsample_type == "resnet": + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + if self.downsample_type == "resnet": + hidden_states = downsampler(hidden_states, temb=temb) + else: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + timestep=temb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + timestep=temb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = attn(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + skip_sample: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + skip_sample: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class ResnetDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + resnets = [] + attentions = [] + + self.attention_head_dim = attention_head_dim + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class KDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + add_downsample: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + # YiYi's comments- might be able to use FirDownsample2D, look into details later + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class KCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + cross_attention_dim: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_group_size: int = 32, + add_downsample: bool = True, + attention_head_dim: int = 64, + add_self_attention: bool = False, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + out_channels, + out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + group_size=resnet_group_size, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + upsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + + self.upsample_type = upsample_type + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if upsample_type == "conv": + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + elif upsample_type == "resnet": + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + if self.upsample_type == "resnet": + hidden_states = upsampler(hidden_states, temb=temb) + else: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + zero_SFTs = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + zero_SFTs.append( + ZeroSFT( + res_skip_channels, + res_skip_channels, + resnet_in_channels + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.zero_SFTs = nn.ModuleList(zero_SFTs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, attn, zero_SFT in zip(self.resnets, self.attentions, self.zero_SFTs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + if isinstance(res_hidden_states, tuple): + # ZeroSFT + hidden_states = zero_SFT(res_hidden_states, hidden_states) + else: + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states[1]+res_hidden_states[0], + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + timestep=temb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + timestep=temb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + zero_SFTs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + zero_SFTs.append( + ZeroSFT( + res_skip_channels, + res_skip_channels, + resnet_in_channels, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.zero_SFTs = nn.ModuleList(zero_SFTs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + for resnet, zero_SFT in zip(self.resnets, self.zero_SFTs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + if isinstance(res_hidden_states, tuple): + # ZeroSFT + hidden_states = zero_SFT(res_hidden_states, hidden_states) + else: + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states[1]+res_hidden_states[0], + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=temb) + hidden_states = attn(hidden_states, temb=temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(resnet_in_channels + res_skip_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if attention_head_dim is None: + logger.warning( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + skip_sample=None, + *args, + **kwargs, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + hidden_states = self.attentions[0](hidden_states) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class SkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + upsample_padding: int = 1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + skip_sample=None, + *args, + **kwargs, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb) + + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + # resnet + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb) + + return hidden_states + + +class KUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + resolution_idx: int, + dropout: float = 0.0, + num_layers: int = 5, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: Optional[int] = 32, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=k_out_channels if (i == num_layers - 1) else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + resolution_idx: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + attention_head_dim: int = 1, # attention dim_head + cross_attention_dim: int = 768, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + # in_channels, and out_channels for the block (k-unet) + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + if is_middle_block and (i == num_layers - 1): + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if (i == num_layers - 1) else out_channels, + k_out_channels // attention_head_dim + if (i == num_layers - 1) + else out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + upcast_attention=upcast_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + attention_bias (`bool`, *optional*, defaults to `False`): + Configure if the attention layers should contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to upcast the attention computation to `float32`. + temb_channels (`int`, *optional*, defaults to 768): + The number of channels in the token embedding. + add_self_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to add self-attention to the block. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + group_size (`int`, *optional*, defaults to 32): + The number of groups to separate the channels into for group normalization. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + upcast_attention: bool = False, + temb_channels: int = 768, # for ada_group_norm + add_self_attention: bool = False, + cross_attention_norm: Optional[str] = None, + group_size: int = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + + # 1. Self-Attn + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + + # 2. Cross-Attn + self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: + return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) + + def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: + return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + # TODO: mark emb as non-optional (self.norm2 requires it). + # requires assessing impact of change to positional param interface. + emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.") + + # 1. Self-Attention + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention/None + norm_hidden_states = self.norm2(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + return hidden_states diff --git a/pipelines/sdxl_instantir.py b/pipelines/sdxl_instantir.py new file mode 100644 index 0000000000000000000000000000000000000000..af6b562cead64a00c84cf1afc36de153d6e65511 --- /dev/null +++ b/pipelines/sdxl_instantir.py @@ -0,0 +1,1740 @@ +# Copyright 2024 The InstantX Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from diffusers.utils.import_utils import is_invisible_watermark_available + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers, LCMScheduler +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, + convert_unet_state_dict_to_peft +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker + +from peft import LoraConfig, set_peft_model_state_dict +from module.aggregator import Aggregator + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install diffusers pillow transformers accelerate + >>> import torch + >>> from PIL import Image + + >>> from diffusers import DDPMScheduler + >>> from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler + + >>> from module.ip_adapter.utils import load_adapter_to_pipe + >>> from pipelines.sdxl_instantir import InstantIRPipeline + + >>> # download models under ./models + >>> dcp_adapter = f'./models/adapter.pt' + >>> previewer_lora_path = f'./models' + >>> instantir_path = f'./models/aggregator.pt' + + >>> # load pretrained models + >>> pipe = InstantIRPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16 + ... ) + >>> # load adapter + >>> load_adapter_to_pipe( + ... pipe, + ... dcp_adapter, + ... image_encoder_or_path = 'facebook/dinov2-large', + ... ) + >>> # load previewer lora + >>> pipe.prepare_previewers(previewer_lora_path) + >>> pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler") + >>> lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config) + + >>> # load aggregator weights + >>> pretrained_state_dict = torch.load(instantir_path) + >>> pipe.aggregator.load_state_dict(pretrained_state_dict) + + >>> # send to GPU and fp16 + >>> pipe.to(device="cuda", dtype=torch.float16) + >>> pipe.aggregator.to(device="cuda", dtype=torch.float16) + >>> pipe.enable_model_cpu_offload() + + >>> # load a broken image + >>> low_quality_image = Image.open('path/to/your-image').convert("RGB") + + >>> # restoration + >>> image = pipe( + ... image=low_quality_image, + ... previewer_scheduler=lcm_scheduler, + ... ).images[0] + ``` +""" + +LCM_LORA_MODULES = [ + "to_q", + "to_k", + "to_v", + "to_out.0", + "proj_in", + "proj_out", + "ff.net.0.proj", + "ff.net.2", + "conv1", + "conv2", + "conv_shortcut", + "downsamplers.0.conv", + "upsamplers.0.conv", + "time_emb_proj", +] +PREVIEWER_LORA_MODULES = [ + "to_q", + "to_kv", + "0.to_out", + "attn1.to_k", + "attn1.to_v", + "to_k_ip", + "to_v_ip", + "ln_k_ip.linear", + "ln_v_ip.linear", + "to_out.0", + "proj_in", + "proj_out", + "ff.net.0.proj", + "ff.net.2", + "conv1", + "conv2", + "conv_shortcut", + "downsamplers.0.conv", + "upsamplers.0.conv", + "time_emb_proj", +] + + +def remove_attn2(model): + def recursive_find_module(name, module): + if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return + elif "resnets" in name: return + if hasattr(module, "attn2"): + setattr(module, "attn2", None) + setattr(module, "norm2", None) + return + for sub_name, sub_module in module.named_children(): + recursive_find_module(f"{name}.{sub_name}", sub_module) + + for name, module in model.named_children(): + recursive_find_module(name, module) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class InstantIRPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + StableDiffusionXLLoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co./openai/clip-vit-large-patch14)). + text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]): + Second frozen text-encoder + ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co./laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + tokenizer_2 ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the `unet` during the denoising process. If you set multiple + ControlNets as a list, the outputs from each ControlNet are added together to create one combined + additional conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings should always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to + watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no + watermarker is used. + """ + + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "feature_extractor", + "image_encoder", + ] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + aggregator: Aggregator = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + feature_extractor: CLIPImageProcessor = None, + image_encoder: CLIPVisionModelWithProjection = None, + ): + super().__init__() + + if aggregator is None: + aggregator = Aggregator.from_unet(unet) + remove_attn2(aggregator) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + aggregator=aggregator, + scheduler=scheduler, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=True + ) + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + def prepare_previewers(self, previewer_lora_path: str, use_lcm=False): + if use_lcm: + lora_state_dict, alpha_dict = self.lora_state_dict( + previewer_lora_path, + ) + else: + lora_state_dict, alpha_dict = self.lora_state_dict( + previewer_lora_path, + weight_name="previewer_lora_weights.bin" + ) + unet_state_dict = { + f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") + } + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + lora_state_dict = dict() + for k, v in unet_state_dict.items(): + if "ip" in k: + k = k.replace("attn2", "attn2.processor") + lora_state_dict[k] = v + else: + lora_state_dict[k] = v + if alpha_dict: + lora_alpha = next(iter(alpha_dict.values())) + else: + lora_alpha = 1 + logger.info(f"use lora alpha {lora_alpha}") + lora_config = LoraConfig( + r=64, + target_modules=LCM_LORA_MODULES if use_lcm else PREVIEWER_LORA_MODULES, + lora_alpha=lora_alpha, + lora_dropout=0.0, + ) + + adapter_name = "lcm" if use_lcm else "previewer" + self.unet.add_adapter(lora_config, adapter_name) + incompatible_keys = set_peft_model_state_dict(self.unet, lora_state_dict, adapter_name=adapter_name) + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if unexpected_keys: + raise ValueError( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + self.unet.disable_adapters() + + return lora_alpha + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + if isinstance(self.image_encoder, CLIPVisionModelWithProjection): + # CLIP image encoder. + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + else: + # DINO image encoder. + image_embeds = self.image_encoder(image).last_hidden_state + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = self.image_encoder( + torch.zeros_like(image) + ).last_hidden_state + uncond_image_embeds = uncond_image_embeds.repeat_interleave( + num_images_per_prompt, dim=0 + ) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + if isinstance(ip_adapter_image[0], list): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + else: + logger.warning( + f"Got {len(ip_adapter_image)} images for {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + " By default, these images will be sent to each IP-Adapter. If this is not your use-case, please specify `ip_adapter_image` as a list of image-list, with" + f" length equals to the number of IP-Adapters." + ) + ip_adapter_image = [ip_adapter_image] * len(self.unet.encoder_hid_proj.image_projection_layers) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = isinstance(self.image_encoder, CLIPVisionModelWithProjection) and not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * (num_images_per_prompt//single_image_embeds.shape[0]), dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * (num_images_per_prompt//single_negative_image_embeds.shape[0]), dim=0 + ) + + if do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + else: + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + image_embeds.append(single_image_embeds) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (ฮท) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to ฮท in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + image, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + negative_pooled_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.aggregator, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.aggregator, Aggregator) + or is_compiled + and isinstance(self.aggregator._orig_mod, Aggregator) + ): + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + if control_guidance_start >= control_guidance_end: + raise ValueError( + f"control guidance start: {control_guidance_start} cannot be larger or equal to control guidance end: {control_guidance_end}." + ) + if control_guidance_start < 0.0: + raise ValueError(f"control guidance start: {control_guidance_start} can't be smaller than 0.") + if control_guidance_end > 1.0: + raise ValueError(f"control guidance end: {control_guidance_end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + return image + + @torch.no_grad() + def init_latents(self, latents, generator, timestep): + noise = torch.randn(latents.shape, generator=generator, device=self.vae.device, dtype=self.vae.dtype, layout=torch.strided) + bsz = latents.shape[0] + print(f"init latent at {timestep}") + timestep = torch.tensor([timestep]*bsz, device=self.vae.device) + # Note that the latents will be scaled aleady by scheduler.add_noise + latents = self.scheduler.add_noise(latents, noise, timestep) + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.FloatTensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + image: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 30, + timesteps: List[int] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + save_preview_row: bool = False, + init_latents_with_lq: bool = True, + multistep_restore: bool = False, + adastep_restore: bool = False, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + controlnet_conditioning_scale: float = 1.0, + control_guidance_start: float = 0.0, + control_guidance_end: float = 1.0, + preview_start: float = 0.0, + preview_end: float = 1.0, + original_size: Tuple[int, int] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + previewer_scheduler: KarrasDiffusionSchedulers = None, + reference_latents: Optional[torch.FloatTensor] = None, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single ControlNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co./stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co./stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co./docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2` + and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (ฮท) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, pooled text embeddings are generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt + weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input + argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the ControlNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the ControlNet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned containing the output images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + aggregator = self.aggregator._orig_mod if is_compiled_module(self.aggregator) else self.aggregator + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] if ip_adapter_image is not None else [image] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + image, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + negative_pooled_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + if not isinstance(image, PIL.Image.Image): + batch_size = len(image) + else: + batch_size = 1 + prompt = [prompt] * batch_size + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + assert batch_size == len(image) or (isinstance(image, PIL.Image.Image) or len(image) == 1) + else: + batch_size = prompt_embeds.shape[0] + assert batch_size == len(image) or (isinstance(image, PIL.Image.Image) or len(image) == 1) + + device = self._execution_device + + # 3.1 Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + + # 3.2 Encode ip_adapter_image + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=aggregator.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + height, width = image.shape[-2:] + if image.shape[1] != 4: + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + if needs_upcasting: + image = image.float() + self.vae.to(dtype=torch.float32) + image = self.vae.encode(image).latent_dist.sample() + image = image * self.vae.config.scaling_factor + if needs_upcasting: + self.vae.to(dtype=torch.float16) + image = image.to(dtype=torch.float16) + else: + height = int(height * self.vae_scale_factor) + width = int(width * self.vae_scale_factor) + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 6. Prepare latent variables + if init_latents_with_lq: + latents = self.init_latents(image, generator, timesteps[0]) + else: + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + previewing = [] + for i in range(len(timesteps)): + keeps = 1.0 - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end) + controlnet_keep.append(keeps) + use_preview = 1.0 - float(i / len(timesteps) < preview_start or (i + 1) / len(timesteps) > preview_end) + previewing.append(use_preview) + if isinstance(controlnet_conditioning_scale, list): + assert len(controlnet_conditioning_scale) == len(timesteps), f"{len(controlnet_conditioning_scale)} controlnet scales do not match number of sampling steps {len(timesteps)}" + else: + controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet_keep) + + # 7.2 Prepare added time ids & embeddings + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + image = torch.cat([image] * 2, dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + is_unet_compiled = is_compiled_module(self.unet) + is_aggregator_compiled = is_compiled_module(self.aggregator) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + previewer_mean = torch.zeros_like(latents) + unet_mean = torch.zeros_like(latents) + preview_factor = torch.ones( + (latents.shape[0], *((1,) * (len(latents.shape) - 1))), dtype=latents.dtype, device=latents.device + ) + + self._num_timesteps = len(timesteps) + preview_row = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_aggregator_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + prev_t = t + unet_model_input = latent_model_input + + added_cond_kwargs = { + "text_embeds": add_text_embeds, + "time_ids": add_time_ids, + "image_embeds": image_embeds + } + aggregator_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + # prepare time_embeds in advance as adapter input + cross_attention_t_emb = self.unet.get_time_embed(sample=latent_model_input, timestep=t) + cross_attention_emb = self.unet.time_embedding(cross_attention_t_emb, timestep_cond) + cross_attention_aug_emb = None + + cross_attention_aug_emb = self.unet.get_aug_embed( + emb=cross_attention_emb, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=added_cond_kwargs + ) + + cross_attention_emb = cross_attention_emb + cross_attention_aug_emb if cross_attention_aug_emb is not None else cross_attention_emb + + if self.unet.time_embed_act is not None: + cross_attention_emb = self.unet.time_embed_act(cross_attention_emb) + + current_cross_attention_kwargs = {"temb": cross_attention_emb} + if cross_attention_kwargs is not None: + for k,v in cross_attention_kwargs.items(): + current_cross_attention_kwargs[k] = v + self._cross_attention_kwargs = current_cross_attention_kwargs + + # adaptive restoration factors + adaRes_scale = preview_factor.to(latent_model_input.dtype).clamp(0.0, controlnet_conditioning_scale[i]) + cond_scale = adaRes_scale * controlnet_keep[i] + cond_scale = torch.cat([cond_scale] * 2) if self.do_classifier_free_guidance else cond_scale + + if (cond_scale>0.1).sum().item() > 0: + if previewing[i] > 0: + # preview with LCM + self.unet.enable_adapters() + preview_noise = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + preview_latent = previewer_scheduler.step( + preview_noise, + t.to(dtype=torch.int64), + # torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, + latent_model_input, # scaled latents here for compatibility + return_dict=False + )[0] + self.unet.disable_adapters() + + if self.do_classifier_free_guidance: + preview_row.append(preview_latent.chunk(2)[1].to('cpu')) + else: + preview_row.append(preview_latent.to('cpu')) + # Prepare 2nd order step. + if multistep_restore and i+1 < len(timesteps): + noise_preview = preview_noise.chunk(2)[1] if self.do_classifier_free_guidance else preview_noise + first_step = self.scheduler.step( + noise_preview, t, latents, + **extra_step_kwargs, return_dict=True, step_forward=False + ) + prev_t = timesteps[i + 1] + unet_model_input = torch.cat([first_step.prev_sample] * 2) if self.do_classifier_free_guidance else first_step.prev_sample + unet_model_input = self.scheduler.scale_model_input(unet_model_input, prev_t, heun_step=True) + + elif reference_latents is not None: + preview_latent = torch.cat([reference_latents] * 2) if self.do_classifier_free_guidance else reference_latents + else: + preview_latent = image + + # Add fresh noise + # preview_noise = torch.randn_like(preview_latent) + # preview_latent = self.scheduler.add_noise(preview_latent, preview_noise, t) + + preview_latent=preview_latent.to(dtype=next(aggregator.parameters()).dtype) + + # Aggregator inference + down_block_res_samples, mid_block_res_sample = aggregator( + image, + prev_t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=preview_latent, + # conditioning_scale=cond_scale, + added_cond_kwargs=aggregator_added_cond_kwargs, + return_dict=False, + ) + + # aggregator features scaling + down_block_res_samples = [sample*cond_scale for sample in down_block_res_samples] + mid_block_res_sample = mid_block_res_sample*cond_scale + + # predict the noise residual + noise_pred = self.unet( + unet_model_input, + prev_t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + unet_step = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True) + latents = unet_step.prev_sample + + # Update adaRes factors + unet_pred_latent = unet_step.pred_original_sample + + # Adaptive restoration. + if adastep_restore: + pred_x0_l2 = ((preview_latent[latents.shape[0]:].float()-unet_pred_latent.float())).pow(2).sum(dim=(1,2,3)) + previewer_l2 = ((preview_latent[latents.shape[0]:].float()-previewer_mean.float())).pow(2).sum(dim=(1,2,3)) + # unet_l2 = ((unet_pred_latent.float()-unet_mean.float())).pow(2).sum(dim=(1,2,3)).sqrt() + # l2_error = (((preview_latent[latents.shape[0]:]-previewer_mean) - (unet_pred_latent-unet_mean))).pow(2).mean(dim=(1,2,3)) + # preview_error = torch.nn.functional.cosine_similarity(preview_latent[latents.shape[0]:].reshape(latents.shape[0], -1), unet_pred_latent.reshape(latents.shape[0],-1)) + previewer_mean = preview_latent[latents.shape[0]:] + unet_mean = unet_pred_latent + preview_factor = (pred_x0_l2 / previewer_l2).reshape(-1, 1, 1, 1) + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + if save_preview_row: + preview_image_row = [] + if needs_upcasting: + self.upcast_vae() + for preview_latents in preview_row: + preview_latents = preview_latents.to(device=self.device, dtype=next(iter(self.vae.post_quant_conv.parameters())).dtype) + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(preview_latents.device, preview_latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(preview_latents.device, preview_latents.dtype) + ) + preview_latents = preview_latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + preview_latents = preview_latents / self.vae.config.scaling_factor + + preview_image = self.vae.decode(preview_latents, return_dict=False)[0] + preview_image = self.image_processor.postprocess(preview_image, output_type=output_type) + preview_image_row.append(preview_image) + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + if save_preview_row: + return (image, preview_image_row) + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/pipelines/stage1_sdxl_pipeline.py b/pipelines/stage1_sdxl_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..be429d9e1baa849a6cf57932f665dd6a1756157c --- /dev/null +++ b/pipelines/stage1_sdxl_pipeline.py @@ -0,0 +1,1283 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPTokenizer, + CLIPVisionModelWithProjection, +) + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import ( + FromSingleFileMixin, + IPAdapterMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, +) +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + FusedAttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...models.lora import adjust_lora_scale_text_encoder +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + is_invisible_watermark_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from .pipeline_output import StableDiffusionXLPipelineOutput + + +if is_invisible_watermark_available(): + from .watermark import StableDiffusionXLWatermarker + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline + + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionXLPipeline( + DiffusionPipeline, + StableDiffusionMixin, + FromSingleFileMixin, + StableDiffusionXLLoraLoaderMixin, + TextualInversionLoaderMixin, + IPAdapterMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co./docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co./openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co./docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co./laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co./docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co./docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`): + Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of + `stabilityai/stable-diffusion-xl-base-1-0`. + add_watermarker (`bool`, *optional*): + Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to + watermark output images. If not defined, it will default to True if the package is installed, otherwise no + watermarker will be used. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae" + _optional_components = [ + "tokenizer", + "tokenizer_2", + "text_encoder", + "text_encoder_2", + "image_encoder", + "feature_extractor", + ] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + "add_text_embeds", + "add_time_ids", + "negative_pooled_prompt_embeds", + "negative_add_time_ids", + ] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + self.default_sample_size = self.unet.config.sample_size + + add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() + + if add_watermarker: + self.watermark = StableDiffusionXLWatermarker() + else: + self.watermark = None + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + # textual inversion: process multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = [negative_prompt, negative_prompt_2] + + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + else: + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + image_embeds.append(single_image_embeds) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (ฮท) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to ฮท in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + FusedAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding( + self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32 + ) -> torch.FloatTensor: + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + w (`torch.Tensor`): + Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings. + embedding_dim (`int`, *optional*, defaults to 512): + Dimension of the embeddings to generate. + dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): + Data type of the generated embeddings. + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`. + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def guidance_rescale(self): + return self._guidance_rescale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def denoising_end(self): + return self._denoising_end + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + negative_original_size: Optional[Tuple[int, int]] = None, + negative_crops_coords_top_left: Tuple[int, int] = (0, 0), + negative_target_size: Optional[Tuple[int, int]] = None, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co./stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + Anything below 512 pixels won't work well for + [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co./stabilityai/stable-diffusion-xl-base-1.0) + and checkpoints that are not specifically fine-tuned on low resolutions. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co./docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (ฮท) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `ฯ†` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). + negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a specific image resolution. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's + micro-conditioning as explained in section 2.2 of + [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + To negatively condition the generation process based on a target image resolution. It should be as same + as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more + information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", + ) + + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._denoising_end = denoising_end + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + if negative_original_size is not None and negative_target_size is not None: + negative_add_time_ids = self._get_add_time_ids( + negative_original_size, + negative_crops_coords_top_left, + negative_target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + else: + negative_add_time_ids = add_time_ids + + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if ( + self.denoising_end is not None + and isinstance(self.denoising_end, float) + and self.denoising_end > 0 + and self.denoising_end < 1 + ): + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (self.denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + # 9. Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, # [B, 77, 2048] + timestep_cond=timestep_cond, # None + cross_attention_kwargs=self.cross_attention_kwargs, # None + added_cond_kwargs=added_cond_kwargs, # {[B, 1280], [B, 6]} + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) + negative_pooled_prompt_embeds = callback_outputs.pop( + "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds + ) + add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) + negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + # make sure the VAE is in float32 mode, as it overflows in float16 + needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast + + if needs_upcasting: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + elif latents.dtype != self.vae.dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + self.vae = self.vae.to(latents.dtype) + + # unscale/denormalize the latents + # denormalize with the mean and std if available and not None + has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None + has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None + if has_latents_mean and has_latents_std: + latents_mean = ( + torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents_std = ( + torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) + ) + latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean + else: + latents = latents / self.vae.config.scaling_factor + + image = self.vae.decode(latents, return_dict=False)[0] + + # cast back to fp16 if needed + if needs_upcasting: + self.vae.to(dtype=torch.float16) + else: + image = latents + + if not output_type == "latent": + # apply watermark if available + if self.watermark is not None: + image = self.watermark.apply_watermark(image) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3c6afbdc100c993a6a21b22d5e976d8239695adb --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +diffusers==0.28.1 +accelerate==0.25.0 +datasets==2.19.1 +einops==0.8.0 +kornia==0.7.2 +numpy==1.26.4 +opencv-python==4.9.0.80 +peft==0.10.0 +pyrallis==0.3.1 +tokenizers==0.15.2 +torch==2.0.1 +torchvision==0.15.2 +transformers==4.36.2 +gradio==4.44.1 \ No newline at end of file diff --git a/schedulers/lcm_single_step_scheduler.py b/schedulers/lcm_single_step_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..a32affdc29c9dd69a35e6220c853ef0d5e98cb74 --- /dev/null +++ b/schedulers/lcm_single_step_scheduler.py @@ -0,0 +1,537 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.utils.torch_utils import randn_tensor +from diffusers.schedulers.scheduling_utils import SchedulerMixin + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class LCMSingleStepSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + denoised: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar( + num_diffusion_timesteps, + max_beta=0.999, + alpha_transform_type="cosine", +): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. + Choose from `cosine` or `exp` + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + if alpha_transform_type == "cosine": + + def alpha_bar_fn(t): + return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 + + elif alpha_transform_type == "exp": + + def alpha_bar_fn(t): + return math.exp(t * -12.0) + + else: + raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr +def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor: + """ + Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) + + + Args: + betas (`torch.FloatTensor`): + the betas that the scheduler is being initialized with. + + Returns: + `torch.FloatTensor`: rescaled betas with zero terminal SNR + """ + # Convert betas to alphas_bar_sqrt + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_bar_sqrt = alphas_cumprod.sqrt() + + # Store old values. + alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() + alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() + + # Shift so the last timestep is zero. + alphas_bar_sqrt -= alphas_bar_sqrt_T + + # Scale so the first timestep is back to the old value. + alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) + + # Convert alphas_bar_sqrt to betas + alphas_bar = alphas_bar_sqrt**2 # Revert sqrt + alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod + alphas = torch.cat([alphas_bar[0:1], alphas]) + betas = 1 - alphas + + return betas + + +class LCMSingleStepScheduler(SchedulerMixin, ConfigMixin): + """ + `LCMSingleStepScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with + non-Markovian guidance. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config + attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be + accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving + functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + beta_start (`float`, defaults to 0.0001): + The starting `beta` value of inference. + beta_end (`float`, defaults to 0.02): + The final `beta` value. + beta_schedule (`str`, defaults to `"linear"`): + The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, *optional*): + Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. + original_inference_steps (`int`, *optional*, defaults to 50): + The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we + will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule. + clip_sample (`bool`, defaults to `True`): + Clip the predicted sample for numerical stability. + clip_sample_range (`float`, defaults to 1.0): + The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, defaults to `True`): + Each diffusion step uses the alphas product value at that step and at the previous one. For the final step + there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the alpha value at step 0. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable + Diffusion. + prediction_type (`str`, defaults to `epsilon`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True`. + timestep_spacing (`str`, defaults to `"leading"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co./papers/2305.08891) for more information. + timestep_scaling (`float`, defaults to 10.0): + The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions + `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation + error at the default of `10.0` is already pretty small). + rescale_betas_zero_snr (`bool`, defaults to `False`): + Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and + dark samples instead of limiting it to samples with medium brightness. Loosely related to + [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). + """ + + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + beta_schedule: str = "scaled_linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + original_inference_steps: int = 50, + clip_sample: bool = False, + clip_sample_range: float = 1.0, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + timestep_spacing: str = "leading", + timestep_scaling: float = 10.0, + rescale_betas_zero_snr: bool = False, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + # Rescale for zero SNR + if rescale_betas_zero_snr: + self.betas = rescale_zero_terminal_snr(self.betas) + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + self._step_index = None + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index + def _init_step_index(self, timestep): + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + + index_candidates = (self.timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + if len(index_candidates) > 1: + step_index = index_candidates[1] + else: + step_index = index_candidates[0] + + self._step_index = step_index.item() + + @property + def step_index(self): + return self._step_index + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + Returns: + `torch.FloatTensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + def set_timesteps( + self, + num_inference_steps: int = None, + device: Union[str, torch.device] = None, + original_inference_steps: Optional[int] = None, + strength: int = 1.0, + timesteps: Optional[list] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + original_inference_steps (`int`, *optional*): + The original number of inference steps, which will be used to generate a linearly-spaced timestep + schedule (which is different from the standard `diffusers` implementation). We will then take + `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as + our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute. + """ + + if num_inference_steps is not None and timesteps is not None: + raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.") + + if timesteps is not None: + for i in range(1, len(timesteps)): + if timesteps[i] >= timesteps[i - 1]: + raise ValueError("`custom_timesteps` must be in descending order.") + + if timesteps[0] >= self.config.num_train_timesteps: + raise ValueError( + f"`timesteps` must start before `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps}." + ) + + timesteps = np.array(timesteps, dtype=np.int64) + else: + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + original_steps = ( + original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps + ) + + if original_steps > self.config.num_train_timesteps: + raise ValueError( + f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + if num_inference_steps > original_steps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:" + f" {original_steps} because the final timestep schedule will be a subset of the" + f" `original_inference_steps`-sized initial timestep schedule." + ) + + # LCM Timesteps Setting + # Currently, only linear spacing is supported. + c = self.config.num_train_timesteps // original_steps + # LCM Training Steps Schedule + lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1 + skipping_step = len(lcm_origin_timesteps) // num_inference_steps + # LCM Inference Steps Schedule + timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps] + + self.timesteps = torch.from_numpy(timesteps.copy()).to(device=device, dtype=torch.long) + + self._step_index = None + + def get_scalings_for_boundary_condition_discrete(self, timestep): + self.sigma_data = 0.5 # Default: 0.5 + scaled_timestep = timestep * self.config.timestep_scaling + + c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5 + return c_skip, c_out + + def append_dims(self, x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + def extract_into_tensor(self, a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.Tensor, + sample: torch.FloatTensor, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[LCMSingleStepSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + # 0. make sure everything is on the same device + alphas_cumprod = self.alphas_cumprod.to(sample.device) + + # 1. compute alphas, betas + if timestep.ndim == 0: + timestep = timestep.unsqueeze(0) + alpha_prod_t = self.extract_into_tensor(alphas_cumprod, timestep, sample.shape) + beta_prod_t = 1 - alpha_prod_t + + # 2. Get scalings for boundary conditions + c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) + c_skip, c_out = [self.append_dims(x, sample.ndim) for x in [c_skip, c_out]] + + # 3. Compute the predicted original sample x_0 based on the model parameterization + if self.config.prediction_type == "epsilon": # noise-prediction + predicted_original_sample = (sample - torch.sqrt(beta_prod_t) * model_output) / torch.sqrt(alpha_prod_t) + elif self.config.prediction_type == "sample": # x-prediction + predicted_original_sample = model_output + elif self.config.prediction_type == "v_prediction": # v-prediction + predicted_original_sample = torch.sqrt(alpha_prod_t) * sample - torch.sqrt(beta_prod_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or" + " `v_prediction` for `LCMScheduler`." + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + predicted_original_sample = self._threshold_sample(predicted_original_sample) + elif self.config.clip_sample: + predicted_original_sample = predicted_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. Denoise model output using boundary conditions + denoised = c_out * predicted_original_sample + c_skip * sample + + if not return_dict: + return (denoised, ) + + return LCMSingleStepSchedulerOutput(denoised=denoised) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/train_previewer_lora.py b/train_previewer_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..d1cf004f5df686efb3d0a0ff71abfd67652c4de7 --- /dev/null +++ b/train_previewer_lora.py @@ -0,0 +1,1712 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The LCM team and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import functools +import gc +import logging +import pyrallis +import math +import os +import random +import shutil +from contextlib import nullcontext +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from PIL import Image +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from collections import namedtuple +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import ( + AutoTokenizer, + PretrainedConfig, + CLIPImageProcessor, CLIPVisionModelWithProjection, + AutoImageProcessor, AutoModel +) + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + LCMScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import cast_training_params, resolve_interpolation_mode +from diffusers.utils import ( + check_min_version, + convert_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + +from basicsr.utils.degradation_pipeline import RealESRGANDegradation +from utils.train_utils import ( + seperate_ip_params_from_unet, + import_model_class_from_model_name_or_path, + tensor_to_pil, + get_train_dataset, prepare_train_dataset, collate_fn, + encode_prompt, importance_sampling_fn, extract_into_tensor + +) +from data.data_config import DataConfig +from losses.loss_config import LossesConfig +from losses.losses import * + +from module.ip_adapter.resampler import Resampler +from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. + +logger = get_logger(__name__) + + +def prepare_latents(lq, vae, scheduler, generator, timestep): + transform = transforms.Compose([ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ]) + lq_pt = [transform(lq_pil.convert("RGB")) for lq_pil in lq] + img_pt = torch.stack(lq_pt).to(vae.device, dtype=vae.dtype) + img_pt = img_pt * 2.0 - 1.0 + with torch.no_grad(): + latents = vae.encode(img_pt).latent_dist.sample() + latents = latents * vae.config.scaling_factor + noise = torch.randn(latents.shape, generator=generator, device=vae.device, dtype=vae.dtype, layout=torch.strided).to(vae.device) + bsz = latents.shape[0] + print(f"init latent at {timestep}") + timestep = torch.tensor([timestep]*bsz, device=vae.device, dtype=torch.int64) + latents = scheduler.add_noise(latents, noise, timestep) + return latents + + +def log_validation(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, + scheduler, image_encoder, image_processor, + args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False): + logger.info("Running validation... ") + + image_logs = [] + + lq = [Image.open(lq_example) for lq_example in args.validation_image] + + pipe = StableDiffusionXLPipeline( + vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, + unet, scheduler, image_encoder, image_processor, + ).to(accelerator.device) + + timesteps = [args.num_train_timesteps - 1] + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + latents = prepare_latents(lq, vae, scheduler, generator, timesteps[-1]) + image = pipe( + prompt=[""]*len(lq), + ip_adapter_image=[lq], + num_inference_steps=1, + timesteps=timesteps, + generator=generator, + guidance_scale=1.0, + height=args.resolution, + width=args.resolution, + latents=latents, + ).images + + if log_local: + # for i, img in enumerate(tensor_to_pil(lq_img)): + # img.save(f"./lq_{i}.png") + # for i, img in enumerate(tensor_to_pil(gt_img)): + # img.save(f"./gt_{i}.png") + for i, img in enumerate(image): + img.save(f"./lq_IPA_{i}.png") + return + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + images = [np.asarray(pil_img) for pil_img in image] + images = np.stack(images, axis=0) + if lq_img is not None and gt_img is not None: + input_lq = lq_img.detach().cpu() + input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1)) + input_gt = gt_img.detach().cpu() + input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1)) + tracker.writer.add_images("lq", input_lq, step, dataformats="NCHW") + tracker.writer.add_images("gt", input_gt, step, dataformats="NCHW") + tracker.writer.add_images("rec", images, step, dataformats="NHWC") + elif tracker.name == "wandb": + raise NotImplementedError("Wandb logging not implemented for validation.") + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +class DDIMSolver: + def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50): + # DDIM sampling parameters + step_ratio = timesteps // ddim_timesteps + + self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1 + self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps] + self.ddim_alpha_cumprods_prev = np.asarray( + [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist() + ) + # convert to torch tensors + self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long() + self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods) + self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev) + + def to(self, device): + self.ddim_timesteps = self.ddim_timesteps.to(device) + self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device) + self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device) + return self + + def ddim_step(self, pred_x0, pred_noise, timestep_index): + alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape) + dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise + x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt + return x_prev + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +# From LCMScheduler.get_scalings_for_boundary_condition_discrete +def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): + scaled_timestep = timestep_scaling * timestep + c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) + c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 + return c_skip, c_out + + +# Compare LCMScheduler.step, Step 4 +def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_x_0 = (sample - sigmas * model_output) / alphas + elif prediction_type == "sample": + pred_x_0 = model_output + elif prediction_type == "v_prediction": + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_x_0 + + +# Based on step 4 in DDIMScheduler.step +def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas): + alphas = extract_into_tensor(alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(sigmas, timesteps, sample.shape) + if prediction_type == "epsilon": + pred_epsilon = model_output + elif prediction_type == "sample": + pred_epsilon = (sample - alphas * model_output) / sigmas + elif prediction_type == "v_prediction": + pred_epsilon = alphas * model_output + sigmas * sample + else: + raise ValueError( + f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`" + f" are supported." + ) + + return pred_epsilon + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + # ----------Model Checkpoint Loading Arguments---------- + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--teacher_revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_lcm_lora_path", + type=str, + default=None, + help="Path to LCM lora or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--feature_extractor_path", + type=str, + default=None, + help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_adapter_model_path", + type=str, + default=None, + help="Path to IP-Adapter models or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--adapter_tokens", + type=int, + default=64, + help="Number of tokens to use in IP-adapter cross attention mechanism.", + ) + parser.add_argument( + "--use_clip_encoder", + action="store_true", + help="Whether or not to use DINO as image encoder, else CLIP encoder.", + ) + parser.add_argument( + "--image_encoder_hidden_feature", + action="store_true", + help="Whether or not to use the penultimate hidden states as image embeddings.", + ) + # ----------Training Arguments---------- + # ----General Training Arguments---- + parser.add_argument( + "--output_dir", + type=str, + default="lcm-xl-distilled", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + # ----Logging---- + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + # ----Checkpointing---- + parser.add_argument( + "--checkpointing_steps", + type=int, + default=4000, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=5, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--save_only_adapter", + action="store_true", + help="Only save extra adapter to save space.", + ) + # ----Image Processing---- + parser.add_argument( + "--data_config_path", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co./docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--text_drop_rate", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--image_drop_rate", + type=float, + default=0, + help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).", + ) + parser.add_argument( + "--cond_drop_rate", + type=float, + default=0, + help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).", + ) + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--encode_batch_size", + type=int, + default=8, + help="Batch size to use for VAE encoding of the images for efficient processing.", + ) + # ----Dataloader---- + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + # ----Batch Size and Training Steps---- + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + # ----Learning Rate---- + parser.add_argument( + "--learning_rate", + type=float, + default=1e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + # ----Optimizer (Adam)---- + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + # ----Diffusion Training Arguments---- + # ----Latent Consistency Distillation (LCD) Specific Arguments---- + parser.add_argument( + "--w_min", + type=float, + default=3.0, + required=False, + help=( + "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) + parser.add_argument( + "--w_max", + type=float, + default=15.0, + required=False, + help=( + "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) + parser.add_argument( + "--num_train_timesteps", + type=int, + default=1000, + help="The number of timesteps to use for DDIM sampling.", + ) + parser.add_argument( + "--num_ddim_timesteps", + type=int, + default=50, + help="The number of timesteps to use for DDIM sampling.", + ) + parser.add_argument( + "--losses_config_path", + type=str, + default='config_files/losses.yaml', + required=True, + help=("A yaml file containing losses to use and their weights."), + ) + parser.add_argument( + "--loss_type", + type=str, + default="l2", + choices=["l2", "huber"], + help="The type of loss to use for the LCD loss.", + ) + parser.add_argument( + "--huber_c", + type=float, + default=0.001, + help="The huber loss parameter. Only used if `--loss_type=huber`.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=64, + help="The rank of the LoRA projection matrix.", + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used. By default, LoRA will be applied to all conv and linear layers." + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + parser.add_argument( + "--timestep_scaling_factor", + type=float, + default=10.0, + help=( + "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The" + " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically" + " suffice." + ), + ) + # ----Mixed Precision---- + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + # ----Training Optimizations---- + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + # ----Distributed Training---- + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + # ----------Validation Arguments---------- + parser.add_argument( + "--validation_steps", + type=int, + default=3000, + help="Run validation every X steps.", + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--sanity_check", + action="store_true", + help=( + "sanity check" + ), + ) + # ----------Huggingface Hub Arguments----------- + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + # ----------Accelerate Arguments---------- + parser.add_argument( + "--tracker_project_name", + type=str, + default="trian", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co./docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + return args + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation. + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # 1. Create the noise scheduler and the desired noise schedule. + noise_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.teacher_revision + ) + noise_scheduler.config.num_train_timesteps = args.num_train_timesteps + lcm_scheduler = LCMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us + alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) + sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # Initialize the DDIM ODE solver for distillation. + solver = DDIMSolver( + noise_scheduler.alphas_cumprod.numpy(), + timesteps=noise_scheduler.config.num_train_timesteps, + ddim_timesteps=args.num_ddim_timesteps, + ) + + # 2. Load tokenizers from SDXL checkpoint. + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False + ) + + # 3. Load text encoders from SDXL checkpoint. + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.teacher_revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.teacher_revision, subfolder="text_encoder_2" + ) + + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.teacher_revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.teacher_revision + ) + + if args.use_clip_encoder: + image_processor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path) + else: + image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path) + image_encoder = AutoModel.from_pretrained(args.feature_extractor_path) + + # 4. Load VAE from SDXL checkpoint (or more stable VAE) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.teacher_revision, + ) + + # 7. Create online student U-Net. + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.teacher_revision + ) + + # Resampler for project model in IP-Adapter + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=args.adapter_tokens, + embedding_dim=image_encoder.config.hidden_size, + output_dim=unet.config.cross_attention_dim, + ff_mult=4 + ) + + # Load the same adapter in both unet. + init_adapter_in_unet( + unet, + image_proj_model, + os.path.join(args.pretrained_adapter_model_path, 'adapter_ckpt.pt'), + adapter_tokens=args.adapter_tokens, + ) + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + if unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + if args.pretrained_lcm_lora_path is not None: + lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(args.pretrained_lcm_lora_path) + unet_state_dict = { + f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") + } + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + lora_state_dict = dict() + for k, v in unet_state_dict.items(): + if "ip" in k: + k = k.replace("attn2", "attn2.processor") + lora_state_dict[k] = v + else: + lora_state_dict[k] = v + if alpha_dict: + args.lora_alpha = next(iter(alpha_dict.values())) + else: + args.lora_alpha = 1 + # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ + "to_q", + "to_kv", + "0.to_out", + "attn1.to_k", + "attn1.to_v", + "to_k_ip", + "to_v_ip", + "ln_k_ip.linear", + "ln_v_ip.linear", + "to_out.0", + "proj_in", + "proj_out", + "ff.net.0.proj", + "ff.net.2", + "conv1", + "conv2", + "conv_shortcut", + "downsamplers.0.conv", + "upsamplers.0.conv", + "time_emb_proj", + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + ) + + # Legacy + # for k, v in lcm_pipe.unet.state_dict().items(): + # if "lora" in k or "base_layer" in k: + # lcm_dict[k.replace("default_0", "default")] = v + + unet.add_adapter(lora_config) + if args.pretrained_lcm_lora_path is not None: + incompatible_keys = set_peft_model_state_dict(unet, lora_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # 6. Freeze unet, vae, text_encoders. + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + image_encoder.requires_grad_(False) + unet.requires_grad_(False) + + # 10. Handle saving and loading of checkpoints + # `accelerate` 0.16.0 will have better support for customized saving + if args.save_only_adapter: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + if isinstance(model, type(unwrap_model(unet))): # save adapter only + unet_ = unwrap_model(model) + # also save the checkpoints in native `diffusers` format so that it can be easily + # be independently loaded via `load_lora_weights()`. + state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_)) + StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict, safe_serialization=False) + + weights.pop() + + def load_model_hook(models, input_dir): + + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + if isinstance(model, type(unwrap_model(unet))): + unet_ = unwrap_model(model) + lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir) + unet_state_dict = { + f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") + } + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + lora_state_dict = dict() + for k, v in unet_state_dict.items(): + if "ip" in k: + k = k.replace("attn2", "attn2.processor") + lora_state_dict[k] = v + else: + lora_state_dict[k] = v + incompatible_keys = set_peft_model_state_dict(unet_, lora_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # 11. Enable optimizations + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co./docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + vae.enable_gradient_checkpointing() + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # 12. Optimizer creation + lora_params, non_lora_params = seperate_lora_params_from_unet(unet) + params_to_optimize = lora_params + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # 13. Dataset creation and data processing + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + datasets = [] + datasets_name = [] + datasets_weights = [] + deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution) + if args.data_config_path is not None: + data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r")) + for single_dataset in data_config.datasets: + datasets_weights.append(single_dataset.dataset_weight) + datasets_name.append(single_dataset.dataset_folder) + dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder) + image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator) + image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline) + datasets.append(image_dataset) + # TODO: Validation dataset + if data_config.val_dataset is not None: + val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator) + logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}") + + # Mix training datasets. + sampler_train = None + if len(datasets) == 1: + train_dataset = datasets[0] + else: + # Weighted each dataset + train_dataset = torch.utils.data.ConcatDataset(datasets) + dataset_weights = [] + for single_dataset, single_weight in zip(datasets, datasets_weights): + dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset)) + sampler_train = torch.utils.data.WeightedRandomSampler( + weights=dataset_weights, + num_samples=len(dataset_weights) + ) + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + sampler=sampler_train, + shuffle=True if sampler_train is None else False, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # 14. Embeddings for the UNet. + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + def compute_embeddings(prompt_batch, original_sizes, crop_coords, text_encoders, tokenizers, is_train=True): + def compute_time_ids(original_size, crops_coords_top_left): + target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + prompt_embeds, pooled_prompt_embeds = encode_prompt(prompt_batch, text_encoders, tokenizers, is_train) + add_text_embeds = pooled_prompt_embeds + + add_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(original_sizes, crop_coords)]) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + + compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers) + + # Move pixels into latents. + @torch.no_grad() + def convert_to_latent(pixels): + model_input = vae.encode(pixels).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + return model_input + + # 15. LR Scheduler creation + # Scheduler and math around the number of training steps. + # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation. + num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes + if args.max_train_steps is None: + len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes) + num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps) + num_training_steps_for_scheduler = ( + args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes + ) + else: + num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + # only upcast trainable parameters (LoRA) into fp32 + cast_training_params(unet, dtype=torch.float32) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=num_warmup_steps_for_scheduler, + num_training_steps=num_training_steps_for_scheduler, + ) + + # 16. Prepare for training + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # 8. Handle mixed precision and device placement + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if args.pretrained_vae_model_name_or_path is None: + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + image_encoder.to(accelerator.device, dtype=weight_dtype) + for p in non_lora_params: + p.data = p.data.to(dtype=weight_dtype) + for p in lora_params: + p.requires_grad_(True) + unet.to(accelerator.device) + + # Also move the alpha and sigma noise schedules to accelerator.device. + alpha_schedule = alpha_schedule.to(accelerator.device) + sigma_schedule = sigma_schedule.to(accelerator.device) + solver = solver.to(accelerator.device) + + # Instantiate Loss. + losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r")) + lcm_losses = list() + for loss_config in losses_configs.lcm_losses: + logger.info(f"Loading lcm loss: {loss_config.name}") + loss = namedtuple("loss", ["loss", "weight"]) + loss_class = eval(loss_config.name) + lcm_losses.append(loss(loss_class( + visualize_every_k=loss_config.visualize_every_k, + dtype=weight_dtype, + accelerator=accelerator, + dino_model=image_encoder, + dino_preprocess=image_processor, + huber_c=args.huber_c, + **loss_config.init_params), weight=loss_config.weight)) + + # Final check. + for n, p in unet.named_parameters(): + if p.requires_grad: + assert "lora" in n, n + assert p.dtype == torch.float32, n + else: + assert "lora" not in n, f"{n}" + assert p.dtype == weight_dtype, n + if args.sanity_check: + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + + # Check input data + batch = next(iter(train_dataloader)) + lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"])) + out_images = log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, + lcm_scheduler, image_encoder, image_processor, + args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, is_final_validation=False, log_local=True) + exit() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes: + logger.warning( + f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match " + f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. " + f"This inconsistency may result in the learning rate scheduler not functioning properly." + ) + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # 17. Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + unet.train() + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + total_loss = torch.tensor(0.0) + bsz = batch["images"].shape[0] + + # Drop conditions. + rand_tensor = torch.rand(bsz) + drop_image_idx = rand_tensor < args.image_drop_rate + drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate) + drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate) + drop_image_idx = drop_image_idx | drop_both_idx + drop_text_idx = drop_text_idx | drop_both_idx + + with torch.no_grad(): + lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"])) + lq_pt = image_processor( + images=lq_img*0.5+0.5, + do_rescale=False, return_tensors="pt" + ).pixel_values + image_embeds = prepare_training_image_embeds( + image_encoder, image_processor, + ip_adapter_image=lq_pt, ip_adapter_image_embeds=None, + device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature, + idx_to_replace=drop_image_idx + ) + uncond_image_embeds = prepare_training_image_embeds( + image_encoder, image_processor, + ip_adapter_image=lq_pt, ip_adapter_image_embeds=None, + device=accelerator.device, drop_rate=1.0, output_hidden_state=args.image_encoder_hidden_feature, + idx_to_replace=torch.ones_like(drop_image_idx) + ) + # 1. Load and process the image and text conditioning + text, orig_size, crop_coords = ( + batch["text"], + batch["original_sizes"], + batch["crop_top_lefts"], + ) + + encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) + uncond_encoded_text = compute_embeddings_fn([""]*len(text), orig_size, crop_coords) + + # encode pixel values with batch size of at most args.vae_encode_batch_size + gt_img = gt_img.to(dtype=vae.dtype) + latents = [] + for i in range(0, gt_img.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(gt_img[i : i + args.vae_encode_batch_size]).latent_dist.sample()) + latents = torch.cat(latents, dim=0) + # latents = convert_to_latent(gt_img) + + latents = latents * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + latents = latents.to(weight_dtype) + + # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias. + # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...] + bsz = latents.shape[0] + topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps + index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long() + start_timesteps = solver.ddim_timesteps[index] + timesteps = start_timesteps - topk + timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps) + + # 3. Get boundary scalings for start_timesteps and (end) timesteps. + c_skip_start, c_out_start = scalings_for_boundary_conditions( + start_timesteps, timestep_scaling=args.timestep_scaling_factor + ) + c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]] + c_skip, c_out = scalings_for_boundary_conditions( + timesteps, timestep_scaling=args.timestep_scaling_factor + ) + c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]] + + # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each + # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1] + noise = torch.randn_like(latents) + noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps) + + # 5. Sample a random guidance scale w from U[w_min, w_max] + # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding + w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min + w = w.reshape(bsz, 1, 1, 1) + w = w.to(device=latents.device, dtype=latents.dtype) + + # 6. Prepare prompt embeds and unet_added_conditions + prompt_embeds = encoded_text.pop("prompt_embeds") + encoded_text["image_embeds"] = image_embeds + uncond_prompt_embeds = uncond_encoded_text.pop("prompt_embeds") + uncond_encoded_text["image_embeds"] = image_embeds + + # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps) + noise_pred = unet( + noisy_model_input, + start_timesteps, + encoder_hidden_states=uncond_prompt_embeds, + added_cond_kwargs=uncond_encoded_text, + ).sample + pred_x_0 = get_predicted_original_sample( + noise_pred, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0 + + # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the + # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these + # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE + # solver timestep. + + # With the adapters disabled, the `unet` is the regular teacher model. + accelerator.unwrap_model(unet).disable_adapters() + with torch.no_grad(): + + # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c + teacher_added_cond = dict() + for k,v in encoded_text.items(): + if isinstance(v, torch.Tensor): + teacher_added_cond[k] = v.to(weight_dtype) + else: + teacher_image_embeds = [] + for img_emb in v: + teacher_image_embeds.append(img_emb.to(weight_dtype)) + teacher_added_cond[k] = teacher_image_embeds + cond_teacher_output = unet( + noisy_model_input, + start_timesteps, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=teacher_added_cond, + ).sample + cond_pred_x0 = get_predicted_original_sample( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + cond_pred_noise = get_predicted_noise( + cond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + + # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0 + teacher_added_uncond = dict() + uncond_encoded_text["image_embeds"] = uncond_image_embeds + for k,v in uncond_encoded_text.items(): + if isinstance(v, torch.Tensor): + teacher_added_uncond[k] = v.to(weight_dtype) + else: + teacher_uncond_image_embeds = [] + for img_emb in v: + teacher_uncond_image_embeds.append(img_emb.to(weight_dtype)) + teacher_added_uncond[k] = teacher_uncond_image_embeds + uncond_teacher_output = unet( + noisy_model_input, + start_timesteps, + encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), + added_cond_kwargs=teacher_added_uncond, + ).sample + uncond_pred_x0 = get_predicted_original_sample( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + uncond_pred_noise = get_predicted_noise( + uncond_teacher_output, + start_timesteps, + noisy_model_input, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + + # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise) + # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation + pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0) + pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise) + # 4. Run one step of the ODE solver to estimate the next point x_prev on the + # augmented PF-ODE trajectory (solving backward in time) + # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0. + x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(weight_dtype) + + # re-enable unet adapters to turn the `unet` into a student unet. + accelerator.unwrap_model(unet).enable_adapters() + + # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps) + # Note that we do not use a separate target network for LCM-LoRA distillation. + with torch.no_grad(): + uncond_encoded_text["image_embeds"] = image_embeds + target_added_cond = dict() + for k,v in uncond_encoded_text.items(): + if isinstance(v, torch.Tensor): + target_added_cond[k] = v.to(weight_dtype) + else: + target_image_embeds = [] + for img_emb in v: + target_image_embeds.append(img_emb.to(weight_dtype)) + target_added_cond[k] = target_image_embeds + target_noise_pred = unet( + x_prev, + timesteps, + encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype), + added_cond_kwargs=target_added_cond, + ).sample + pred_x_0 = get_predicted_original_sample( + target_noise_pred, + timesteps, + x_prev, + noise_scheduler.config.prediction_type, + alpha_schedule, + sigma_schedule, + ) + target = c_skip * x_prev + c_out * pred_x_0 + + # 10. Calculate loss + lcm_loss_arguments = { + "target": target.float(), + "predict": model_pred.float(), + } + loss_dict = dict() + # total_loss = total_loss + torch.mean( + # torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c + # ) + # loss_dict["L2Loss"] = total_loss.item() + for loss_config in lcm_losses: + if loss_config.loss.__class__.__name__=="DINOLoss": + with torch.no_grad(): + pixel_target = [] + latent_target = target.to(dtype=vae.dtype) + for i in range(0, latent_target.shape[0], args.vae_encode_batch_size): + pixel_target.append( + vae.decode( + latent_target[i : i + args.vae_encode_batch_size] / vae.config.scaling_factor, + return_dict=False + )[0] + ) + pixel_target = torch.cat(pixel_target, dim=0) + pixel_pred = [] + latent_pred = model_pred.to(dtype=vae.dtype) + for i in range(0, latent_pred.shape[0], args.vae_encode_batch_size): + pixel_pred.append( + vae.decode( + latent_pred[i : i + args.vae_encode_batch_size] / vae.config.scaling_factor, + return_dict=False + )[0] + ) + pixel_pred = torch.cat(pixel_pred, dim=0) + dino_loss_arguments = { + "target": pixel_target, + "predict": pixel_pred, + } + non_weighted_loss = loss_config.loss(**dino_loss_arguments, accelerator=accelerator) + loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item() + total_loss = total_loss + non_weighted_loss * loss_config.weight + else: + non_weighted_loss = loss_config.loss(**lcm_loss_arguments, accelerator=accelerator) + total_loss = total_loss + non_weighted_loss * loss_config.weight + loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item() + + # 11. Backpropagate on the online student model (`unet`) (only LoRA) + accelerator.backward(total_loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + out_images = log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, + lcm_scheduler, image_encoder, image_processor, + args, accelerator, weight_dtype, global_step, lq_img, gt_img, is_final_validation=False, log_local=False) + + logs = dict() + # logs.update({"loss": loss.detach().item()}) + logs.update(loss_dict) + logs.update({"lr": lr_scheduler.get_last_lr()[0]}) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) + StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict) + + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + del unet + torch.cuda.empty_cache() + + # Final inference. + if args.validation_steps is not None: + log_validation(unwrap_model(unet), vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, + lcm_scheduler, image_encoder=None, image_processor=None, + args=args, accelerator=accelerator, weight_dtype=weight_dtype, step=0, is_final_validation=False, log_local=True) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/train_previewer_lora.sh b/train_previewer_lora.sh new file mode 100644 index 0000000000000000000000000000000000000000..29db5164dbb77f8025c67bd0f0ca379c793d2212 --- /dev/null +++ b/train_previewer_lora.sh @@ -0,0 +1,24 @@ +# After DCP training, distill the Previewer with DCP in `train_previewer_lora.py`: +accelerate launch --num_processes train_previewer_lora.py \ + --output_dir \ + --train_data_dir \ + --logging_dir \ + --pretrained_model_name_or_path \ + --feature_extractor_path \ + --pretrained_adapter_model_path \ + --losses_config_path config_files/losses.yaml \ + --data_config_path config_files/IR_dataset.yaml \ + --save_only_adapter \ + --gradient_checkpointing \ + --num_train_timesteps 1000 \ + --num_ddim_timesteps 50 \ + --lora_alpha 1 \ + --mixed_precision fp16 \ + --train_batch_size 32 \ + --vae_encode_batch_size 16 \ + --gradient_accumulation_steps 1 \ + --learning_rate 1e-4 \ + --lr_warmup_steps 1000 \ + --lr_scheduler cosine \ + --lr_num_cycles 1 \ + --resume_from_checkpoint latest \ No newline at end of file diff --git a/train_stage1_adapter.py b/train_stage1_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..ec2ee3326ec2c6b85c0e164cc3c1282b76d91143 --- /dev/null +++ b/train_stage1_adapter.py @@ -0,0 +1,1259 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import contextlib +import time +import gc +import logging +import math +import os +import random +import jsonlines +import functools +import shutil +import pyrallis +import itertools +from pathlib import Path +from collections import namedtuple, OrderedDict + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from datasets import load_dataset +from packaging import version +from PIL import Image +from data.data_config import DataConfig +from basicsr.utils.degradation_pipeline import RealESRGANDegradation +from losses.loss_config import LossesConfig +from losses.losses import * +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import ( + AutoTokenizer, + PretrainedConfig, + CLIPImageProcessor, CLIPVisionModelWithProjection, + AutoImageProcessor, AutoModel) + +import diffusers +from diffusers import ( + AutoencoderKL, + AutoencoderTiny, + DDPMScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available, make_image_grid +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + +from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler +from utils.train_utils import ( + seperate_ip_params_from_unet, + import_model_class_from_model_name_or_path, + tensor_to_pil, + get_train_dataset, prepare_train_dataset, collate_fn, + encode_prompt, importance_sampling_fn, extract_into_tensor +) +from module.ip_adapter.resampler import Resampler +from module.ip_adapter.attention_processor import init_attn_proc +from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds + + +if is_wandb_available(): + import wandb + + +logger = get_logger(__name__) + + +def log_validation(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, + scheduler, image_encoder, image_processor, deg_pipeline, + args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False): + logger.info("Running validation... ") + + image_logs = [] + + lq = [Image.open(lq_example) for lq_example in args.validation_image] + + pipe = StableDiffusionXLPipeline( + vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, + unet, scheduler, image_encoder, image_processor, + ).to(accelerator.device) + + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + image = pipe( + prompt=[""]*len(lq), + ip_adapter_image=[lq], + num_inference_steps=20, + generator=generator, + guidance_scale=5.0, + height=args.resolution, + width=args.resolution, + ).images + + if log_local: + for i, img in enumerate(tensor_to_pil(lq_img)): + img.save(f"./lq_{i}.png") + for i, img in enumerate(tensor_to_pil(gt_img)): + img.save(f"./gt_{i}.png") + for i, img in enumerate(image): + img.save(f"./lq_IPA_{i}.png") + return + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + images = [np.asarray(pil_img) for pil_img in image] + images = np.stack(images, axis=0) + if lq_img is not None and gt_img is not None: + input_lq = lq_img.detach().cpu() + input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1)) + input_gt = gt_img.detach().cpu() + input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1)) + tracker.writer.add_images("lq", input_lq[0], step, dataformats="CHW") + tracker.writer.add_images("gt", input_gt[0], step, dataformats="CHW") + tracker.writer.add_images("rec", images, step, dataformats="NHWC") + elif tracker.name == "wandb": + raise NotImplementedError("Wandb logging not implemented for validation.") + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="InstantIR stage-1 training.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--feature_extractor_path", + type=str, + default=None, + help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_adapter_model_path", + type=str, + default=None, + help="Path to IP-Adapter models or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--adapter_tokens", + type=int, + default=64, + help="Number of tokens to use in IP-adapter cross attention mechanism.", + ) + parser.add_argument( + "--use_clip_encoder", + action="store_true", + help="Whether or not to use DINO as image encoder, else CLIP encoder.", + ) + parser.add_argument( + "--image_encoder_hidden_feature", + action="store_true", + help="Whether or not to use the penultimate hidden states as image embeddings.", + ) + parser.add_argument( + "--losses_config_path", + type=str, + required=True, + default='config_files/losses.yaml' + help=("A yaml file containing losses to use and their weights."), + ) + parser.add_argument( + "--data_config_path", + type=str, + default='config_files/IR_dataset.yaml', + help=("A folder containing the training data. "), + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="stage1_model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=2000, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co./docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=5, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--save_only_adapter", + action="store_true", + help="Only save extra adapter to save space.", + ) + parser.add_argument( + "--importance_sampling", + action="store_true", + help="Whether or not to use importance sampling.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that ๐Ÿค— Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co./docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--text_drop_rate", + type=float, + default=0.05, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--image_drop_rate", + type=float, + default=0.05, + help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).", + ) + parser.add_argument( + "--cond_drop_rate", + type=float, + default=0.05, + help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).", + ) + parser.add_argument( + "--sanity_check", + action="store_true", + help=( + "sanity check" + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=3000, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="instantir_stage1", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co./docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + # if args.dataset_name is None and args.train_data_dir is None and args.data_config_path is None: + # raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.text_drop_rate < 0 or args.text_drop_rate > 1: + raise ValueError("`--text_drop_rate` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + # kwargs_handlers=[kwargs], + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation. + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + # Importance sampling. + list_of_candidates = np.arange(noise_scheduler.config.num_train_timesteps, dtype='float64') + prob_dist = importance_sampling_fn(list_of_candidates, noise_scheduler.config.num_train_timesteps, 0.5) + importance_ratio = prob_dist / prob_dist.sum() * noise_scheduler.config.num_train_timesteps + importance_ratio = torch.from_numpy(importance_ratio.copy()).float() + + # Load the tokenizers + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + tokenizer_2 = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, + ) + + # Text encoder and image encoder. + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + text_encoder = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_2 = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + if args.use_clip_encoder: + image_processor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path) + else: + image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path) + image_encoder = AutoModel.from_pretrained(args.feature_extractor_path) + + # VAE. + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + + # UNet. + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant + ) + + pipe = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=unet, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + vae=vae, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + variant=args.variant + ) + + # Resampler for project model in IP-Adapter + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=args.adapter_tokens, + embedding_dim=image_encoder.config.hidden_size, + output_dim=unet.config.cross_attention_dim, + ff_mult=4 + ) + + init_adapter_in_unet( + unet, + image_proj_model, + os.path.join(args.pretrained_adapter_model_path, 'adapter_ckpt.pt'), + adapter_tokens=args.adapter_tokens, + ) + + # Initialize training state. + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + text_encoder_2.requires_grad_(False) + unet.requires_grad_(False) + image_encoder.requires_grad_(False) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # `accelerate` 0.16.0 will have better support for customized saving + if args.save_only_adapter: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + if isinstance(model, type(unwrap_model(unet))): # save adapter only + adapter_state_dict = OrderedDict() + adapter_state_dict["image_proj"] = model.encoder_hid_proj.image_projection_layers[0].state_dict() + adapter_state_dict["ip_adapter"] = torch.nn.ModuleList(model.attn_processors.values()).state_dict() + torch.save(adapter_state_dict, os.path.join(output_dir, "adapter_ckpt.pt")) + + weights.pop() + + def load_model_hook(models, input_dir): + + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + if isinstance(model, type(accelerator.unwrap_model(unet))): + adapter_state_dict = torch.load(os.path.join(input_dir, "adapter_ckpt.pt"), map_location="cpu") + if list(adapter_state_dict.keys()) != ["image_proj", "ip_adapter"]: + from module.ip_adapter.utils import revise_state_dict + adapter_state_dict = revise_state_dict(adapter_state_dict) + model.encoder_hid_proj.image_projection_layers[0].load_state_dict(adapter_state_dict["image_proj"], strict=True) + missing, unexpected = torch.nn.ModuleList(model.attn_processors.values()).load_state_dict(adapter_state_dict["ip_adapter"], strict=False) + if len(unexpected) > 0: + raise ValueError(f"Unexpected keys: {unexpected}") + if len(missing) > 0: + for mk in missing: + if "ln" not in mk: + raise ValueError(f"Missing keys: {missing}") + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co./docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + vae.enable_gradient_checkpointing() + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation. + ip_params, non_ip_params = seperate_ip_params_from_unet(unet) + params_to_optimize = ip_params + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Instantiate Loss. + losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r")) + diffusion_losses = list() + for loss_config in losses_configs.diffusion_losses: + logger.info(f"Loading diffusion loss: {loss_config.name}") + loss = namedtuple("loss", ["loss", "weight"]) + loss_class = eval(loss_config.name) + diffusion_losses.append(loss(loss_class(visualize_every_k=loss_config.visualize_every_k, + dtype=weight_dtype, + accelerator=accelerator, + **loss_config.init_params), weight=loss_config.weight)) + + # SDXL additional condition that will be added to time embedding. + def compute_time_ids(original_size, crops_coords_top_left): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + # Text prompt embeddings. + @torch.no_grad() + def compute_embeddings(batch, text_encoders, tokenizers, drop_idx=None, is_train=True): + prompt_batch = batch[args.caption_column] + if drop_idx is not None: + for i in range(len(prompt_batch)): + prompt_batch[i] = "" if drop_idx[i] else prompt_batch[i] + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, is_train + ) + + add_time_ids = torch.cat( + [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])] + ) + + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + sdxl_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids} + + return prompt_embeds, sdxl_added_cond_kwargs + + # Move pixels into latents. + @torch.no_grad() + def convert_to_latent(pixels): + model_input = vae.encode(pixels).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + return model_input + + # Datasets and other data moduels. + deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution) + compute_embeddings_fn = functools.partial( + compute_embeddings, + text_encoders=[text_encoder, text_encoder_2], + tokenizers=[tokenizer, tokenizer_2], + is_train=True, + ) + + datasets = [] + datasets_name = [] + datasets_weights = [] + if args.data_config_path is not None: + data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r")) + for single_dataset in data_config.datasets: + datasets_weights.append(single_dataset.dataset_weight) + datasets_name.append(single_dataset.dataset_folder) + dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder) + image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator) + image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline) + datasets.append(image_dataset) + # TODO: Validation dataset + if data_config.val_dataset is not None: + val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator) + logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}") + + # Mix training datasets. + sampler_train = None + if len(datasets) == 1: + train_dataset = datasets[0] + else: + # Weighted each dataset + train_dataset = torch.utils.data.ConcatDataset(datasets) + dataset_weights = [] + for single_dataset, single_weight in zip(datasets, datasets_weights): + dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset)) + sampler_train = torch.utils.data.WeightedRandomSampler( + weights=dataset_weights, + num_samples=len(dataset_weights) + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + sampler=sampler_train, + shuffle=True if sampler_train is None else False, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers + ) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # Move vae, unet and text_encoder to device and cast to weight_dtype + if args.pretrained_vae_model_name_or_path is None: + # The VAE is fp32 to avoid NaN losses. + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder_2.to(accelerator.device, dtype=weight_dtype) + image_encoder.to(accelerator.device, dtype=weight_dtype) + importance_ratio = importance_ratio.to(accelerator.device) + for non_ip_param in non_ip_params: + non_ip_param.data = non_ip_param.data.to(dtype=weight_dtype) + for ip_param in ip_params: + ip_param.requires_grad_(True) + unet.to(accelerator.device) + + # Final check. + for n, p in unet.named_parameters(): + if p.requires_grad: assert p.dtype == torch.float32, n + else: assert p.dtype == weight_dtype, n + if args.sanity_check: + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + + # Check input data + batch = next(iter(train_dataloader)) + lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"])) + images_log = log_validation( + unwrap_model(unet), vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, + noise_scheduler, image_encoder, image_processor, deg_pipeline, + args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, is_final_validation=False, log_local=True + ) + exit() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Optimization steps per epoch = {num_update_steps_per_epoch}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + trainable_models = [unet] + + if args.gradient_checkpointing: + checkpoint_models = [] + else: + checkpoint_models = [] + + image_logs = None + tic = time.time() + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + toc = time.time() + io_time = toc - tic + tic = toc + for model in trainable_models + checkpoint_models: + model.train() + with accelerator.accumulate(*trainable_models): + loss = torch.tensor(0.0) + + # Drop conditions. + rand_tensor = torch.rand(batch["images"].shape[0]) + drop_image_idx = rand_tensor < args.image_drop_rate + drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate) + drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate) + drop_image_idx = drop_image_idx | drop_both_idx + drop_text_idx = drop_text_idx | drop_both_idx + + # Get LQ embeddings + with torch.no_grad(): + lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"])) + lq_pt = image_processor( + images=lq_img*0.5+0.5, + do_rescale=False, return_tensors="pt" + ).pixel_values + image_embeds = prepare_training_image_embeds( + image_encoder, image_processor, + ip_adapter_image=lq_pt, ip_adapter_image_embeds=None, + device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature, + idx_to_replace=drop_image_idx + ) + + # Process text inputs. + prompt_embeds_input, added_conditions = compute_embeddings_fn(batch, drop_idx=drop_text_idx) + added_conditions["image_embeds"] = image_embeds + + # Move inputs to latent space. + gt_img = gt_img.to(dtype=vae.dtype) + model_input = convert_to_latent(gt_img) + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + + # Sample noise that we'll add to the latents. + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + + # Sample a random timestep for each image. + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + loss_weights = extract_into_tensor(importance_ratio, timesteps, noise.shape) if args.importance_sampling else None + + toc = time.time() + prepare_time = toc - tic + tic = time.time() + + model_pred = unet( + noisy_model_input, timesteps, + encoder_hidden_states=prompt_embeds_input, + added_cond_kwargs=added_conditions, + return_dict=False + )[0] + + diffusion_loss_arguments = { + "target": noise, + "predict": model_pred, + "prompt_embeddings_input": prompt_embeds_input, + "timesteps": timesteps, + "weights": loss_weights, + } + + loss_dict = dict() + for loss_config in diffusion_losses: + non_weighted_loss = loss_config.loss(**diffusion_loss_arguments, accelerator=accelerator) + loss = loss + non_weighted_loss * loss_config.weight + loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item() + + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + toc = time.time() + forward_time = toc - tic + tic = toc + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + image_logs = log_validation(unwrap_model(unet), vae, + text_encoder, text_encoder_2, tokenizer, tokenizer_2, + noise_scheduler, image_encoder, image_processor, deg_pipeline, + args, accelerator, weight_dtype, global_step, lq_img, gt_img, is_final_validation=False) + + logs = {} + logs.update(loss_dict) + logs.update({ + "lr": lr_scheduler.get_last_lr()[0], + "io_time": io_time, + "prepare_time": prepare_time, + "forward_time": forward_time + }) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + tic = time.time() + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + accelerator.save_state(os.path.join(args.output_dir, "last"), safe_serialization=False) + # Run a final round of validation. + # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`. + image_logs = None + if args.validation_image is not None: + image_logs = log_validation( + unwrap_model(unet), vae, + text_encoder, text_encoder_2, tokenizer, tokenizer_2, + noise_scheduler, image_encoder, image_processor, deg_pipeline, + args, accelerator, weight_dtype, global_step, + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/train_stage1_adapter.sh b/train_stage1_adapter.sh new file mode 100644 index 0000000000000000000000000000000000000000..71fee17856de283ec47183a61bfb6d5b4ae9ec72 --- /dev/null +++ b/train_stage1_adapter.sh @@ -0,0 +1,17 @@ +# Stage 1: training lq adapter +accelerate launch --num_processes train_stage1_adapter.py \ + --output_dir \ + --train_data_dir \ + --logging_dir \ + --pretrained_model_name_or_path \ + --feature_extractor_path \ + --save_only_adapter \ + --gradient_checkpointing \ + --mixed_precision fp16 \ + --train_batch_size 96 \ + --gradient_accumulation_steps 1 \ + --learning_rate 1e-4 \ + --lr_warmup_steps 1000 \ + --lr_scheduler cosine \ + --lr_num_cycles 1 \ + --resume_from_checkpoint latest \ No newline at end of file diff --git a/train_stage2_aggregator.py b/train_stage2_aggregator.py new file mode 100644 index 0000000000000000000000000000000000000000..31c617d8ac295e5a7bca637cc029a5af6c85be27 --- /dev/null +++ b/train_stage2_aggregator.py @@ -0,0 +1,1698 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import os +import argparse +import time +import gc +import logging +import math +import copy +import random +import yaml +import functools +import shutil +import pyrallis +from pathlib import Path +from collections import namedtuple, OrderedDict + +import accelerate +import numpy as np +import torch +from safetensors import safe_open +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from datasets import load_dataset +from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from data.data_config import DataConfig +from basicsr.utils.degradation_pipeline import RealESRGANDegradation +from losses.loss_config import LossesConfig +from losses.losses import * +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm +from transformers import ( + AutoTokenizer, + PretrainedConfig, + CLIPImageProcessor, CLIPVisionModelWithProjection, + AutoImageProcessor, AutoModel +) + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import ( + check_min_version, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module + +from module.aggregator import Aggregator +from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler +from module.ip_adapter.ip_adapter import MultiIPAdapterImageProjection +from module.ip_adapter.resampler import Resampler +from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds +from module.ip_adapter.attention_processor import init_attn_proc +from utils.train_utils import ( + seperate_ip_params_from_unet, + import_model_class_from_model_name_or_path, + tensor_to_pil, + get_train_dataset, prepare_train_dataset, collate_fn, + encode_prompt, importance_sampling_fn, extract_into_tensor +) +from pipelines.sdxl_instantir import InstantIRPipeline + + +if is_wandb_available(): + import wandb + + +logger = get_logger(__name__) + + +def log_validation(unet, aggregator, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, + scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline, + args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False): + logger.info("Running validation... ") + + image_logs = [] + + # validation_batch = batchify_pil(args.validation_image, args.validation_prompt, deg_pipeline, image_processor) + lq = [Image.open(lq_example).convert("RGB") for lq_example in args.validation_image] + + pipe = InstantIRPipeline( + vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, + unet, scheduler, aggregator, feature_extractor=image_processor, image_encoder=image_encoder, + ).to(accelerator.device) + + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + if lq_img is not None and gt_img is not None: + lq_img = lq_img[:len(args.validation_image)] + lq_pt = image_processor( + images=lq_img*0.5+0.5, + do_rescale=False, return_tensors="pt" + ).pixel_values + image = pipe( + prompt=[""]*len(lq_img), + image=lq_img, + ip_adapter_image=lq_pt, + num_inference_steps=20, + generator=generator, + controlnet_conditioning_scale=1.0, + negative_prompt=[""]*len(lq), + guidance_scale=5.0, + height=args.resolution, + width=args.resolution, + lcm_scheduler=lcm_scheduler, + ).images + else: + image = pipe( + prompt=[""]*len(lq), + image=lq, + ip_adapter_image=lq, + num_inference_steps=20, + generator=generator, + controlnet_conditioning_scale=1.0, + negative_prompt=[""]*len(lq), + guidance_scale=5.0, + height=args.resolution, + width=args.resolution, + lcm_scheduler=lcm_scheduler, + ).images + + if log_local: + for i, rec_image in enumerate(image): + rec_image.save(f"./instantid_{i}.png") + return + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + images = [np.asarray(pil_img) for pil_img in image] + images = np.stack(images, axis=0) + if lq_img is not None and gt_img is not None: + input_lq = lq_img.cpu() + input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1)) + input_gt = gt_img.cpu() + input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1)) + tracker.writer.add_images("lq", input_lq, step, dataformats="NCHW") + tracker.writer.add_images("gt", input_gt, step, dataformats="NCHW") + tracker.writer.add_images("rec", images, step, dataformats="NHWC") + elif tracker.name == "wandb": + raise NotImplementedError("Wandb logging not implemented for validation.") + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def remove_attn2(model): + def recursive_find_module(name, module): + if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return + elif "resnets" in name: return + if hasattr(module, "attn2"): + setattr(module, "attn2", None) + setattr(module, "norm2", None) + return + for sub_name, sub_module in module.named_children(): + recursive_find_module(f"{name}.{sub_name}", sub_module) + + for name, module in model.named_children(): + recursive_find_module(name, module) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a IP-Adapter training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to an pretrained controlnet model like tile-controlnet.", + ) + parser.add_argument( + "--use_lcm", + action="store_true", + help="Whether or not to use lcm unet.", + ) + parser.add_argument( + "--pretrained_lcm_lora_path", + type=str, + default=None, + help="Path to LCM lora or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=64, + help="The rank of the LoRA projection matrix.", + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used. By default, LoRA will be applied to all conv and linear layers." + ), + ) + parser.add_argument( + "--feature_extractor_path", + type=str, + default=None, + help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_adapter_model_path", + type=str, + default=None, + help="Path to IP-Adapter models or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--adapter_tokens", + type=int, + default=64, + help="Number of tokens to use in IP-adapter cross attention mechanism.", + ) + parser.add_argument( + "--aggregator_adapter", + action="store_true", + help="Whether or not to add adapter on aggregator.", + ) + parser.add_argument( + "--optimize_adapter", + action="store_true", + help="Whether or not to optimize IP-Adapter.", + ) + parser.add_argument( + "--image_encoder_hidden_feature", + action="store_true", + help="Whether or not to use the penultimate hidden states as image embeddings.", + ) + parser.add_argument( + "--losses_config_path", + type=str, + required=True, + help=("A yaml file containing losses to use and their weights."), + ) + parser.add_argument( + "--data_config_path", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="stage1_model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=3000, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co./docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=5, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--previous_ckpt", + type=str, + default=None, + help=( + "Whether training should be initialized from a previous checkpoint." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--save_only_adapter", + action="store_true", + help="Only save extra adapter to save space.", + ) + parser.add_argument( + "--cache_prompt_embeds", + action="store_true", + help="Whether or not to cache prompt embeds to save memory.", + ) + parser.add_argument( + "--importance_sampling", + action="store_true", + help="Whether or not to use importance sampling.", + ) + parser.add_argument( + "--CFG_scale", + type=float, + default=1.0, + help="CFG for previewer.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that ๐Ÿค— Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co./docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--text_drop_rate", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--image_drop_rate", + type=float, + default=0, + help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).", + ) + parser.add_argument( + "--cond_drop_rate", + type=float, + default=0, + help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).", + ) + parser.add_argument( + "--use_ema_adapter", + action="store_true", + help=( + "use ema ip-adapter for LCM preview" + ), + ) + parser.add_argument( + "--sanity_check", + action="store_true", + help=( + "sanity check" + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=4000, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default='train', + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co./docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if not args.sanity_check and args.dataset_name is None and args.train_data_dir is None and args.data_config_path is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.text_drop_rate < 0 or args.text_drop_rate > 1: + raise ValueError("`--text_drop_rate` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def update_ema_model(ema_model, model, ema_beta): + for ema_param, param in zip(ema_model.parameters(), model.parameters()): + ema_param.copy_(param.detach().lerp(ema_param, ema_beta)) + + +def copy_dict(dict): + new_dict = {} + for key, value in dict.items(): + new_dict[key] = value + return new_dict + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation. + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + + # Importance sampling. + list_of_candidates = np.arange(noise_scheduler.config.num_train_timesteps, dtype='float64') + prob_dist = importance_sampling_fn(list_of_candidates, noise_scheduler.config.num_train_timesteps, 0.5) + importance_ratio = prob_dist / prob_dist.sum() * noise_scheduler.config.num_train_timesteps + importance_ratio = torch.from_numpy(importance_ratio.copy()).float() + + # Load the tokenizers + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + tokenizer_2 = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, + ) + + # Text encoder and image encoder. + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + text_encoder = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_2 = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + + # Image processor and image encoder. + if args.use_clip_encoder: + image_processor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path) + else: + image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path) + image_encoder = AutoModel.from_pretrained(args.feature_extractor_path) + + # VAE. + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + + # UNet. + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.revision, + variant=args.variant + ) + + # Aggregator. + aggregator = Aggregator.from_unet(unet) + remove_attn2(aggregator) + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + if args.controlnet_model_name_or_path.endswith(".safetensors"): + pretrained_cn_state_dict = {} + with safe_open(args.controlnet_model_name_or_path, framework="pt", device='cpu') as f: + for key in f.keys(): + pretrained_cn_state_dict[key] = f.get_tensor(key) + else: + pretrained_cn_state_dict = torch.load(os.path.join(args.controlnet_model_name_or_path, "aggregator_ckpt.pt"), map_location="cpu") + aggregator.load_state_dict(pretrained_cn_state_dict, strict=True) + else: + logger.info("Initializing aggregator weights from unet.") + + # Create image embedding projector for IP-Adapters. + if args.pretrained_adapter_model_path is not None: + if args.pretrained_adapter_model_path.endswith(".safetensors"): + pretrained_adapter_state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(args.pretrained_adapter_model_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + pretrained_adapter_state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + pretrained_adapter_state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + pretrained_adapter_state_dict = torch.load(args.pretrained_adapter_model_path, map_location="cpu") + + # Image embedding Projector. + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=args.adapter_tokens, + embedding_dim=image_encoder.config.hidden_size, + output_dim=unet.config.cross_attention_dim, + ff_mult=4 + ) + + init_adapter_in_unet( + unet, + image_proj_model, + pretrained_adapter_state_dict, + adapter_tokens=args.adapter_tokens, + ) + + # EMA adapter for LCM preview. + if args.use_ema_adapter: + assert args.optimize_adapter, "No need for EMA with frozen adapter." + ema_image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=args.adapter_tokens, + embedding_dim=image_encoder.config.hidden_size, + output_dim=unet.config.cross_attention_dim, + ff_mult=4 + ) + orig_encoder_hid_proj = unet.encoder_hid_proj + ema_encoder_hid_proj = MultiIPAdapterImageProjection([ema_image_proj_model]) + orig_attn_procs = unet.attn_processors + orig_attn_procs_list = torch.nn.ModuleList(orig_attn_procs.values()) + ema_attn_procs = init_attn_proc(unet, args.adapter_tokens, True, True, False) + ema_attn_procs_list = torch.nn.ModuleList(ema_attn_procs.values()) + ema_attn_procs_list.requires_grad_(False) + ema_encoder_hid_proj.requires_grad_(False) + + # Initialize EMA state. + ema_beta = 0.5 ** (args.ema_update_steps / max(args.ema_halflife_steps, 1e-8)) + logger.info(f"Using EMA with beta: {ema_beta}") + ema_encoder_hid_proj.load_state_dict(orig_encoder_hid_proj.state_dict()) + ema_attn_procs_list.load_state_dict(orig_attn_procs_list.state_dict()) + + # Projector for aggregator. + if args.aggregator_adapter: + image_proj_model = Resampler( + dim=1280, + depth=4, + dim_head=64, + heads=20, + num_queries=args.adapter_tokens, + embedding_dim=image_encoder.config.hidden_size, + output_dim=unet.config.cross_attention_dim, + ff_mult=4 + ) + + init_adapter_in_unet( + aggregator, + image_proj_model, + pretrained_adapter_state_dict, + adapter_tokens=args.adapter_tokens, + ) + del pretrained_adapter_state_dict + + # Load LCM LoRA into unet. + if args.pretrained_lcm_lora_path is not None: + lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(args.pretrained_lcm_lora_path) + unet_state_dict = { + f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.") + } + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + lora_state_dict = dict() + for k, v in unet_state_dict.items(): + if "ip" in k: + k = k.replace("attn2", "attn2.processor") + lora_state_dict[k] = v + else: + lora_state_dict[k] = v + if alpha_dict: + args.lora_alpha = next(iter(alpha_dict.values())) + else: + args.lora_alpha = 1 + logger.info(f"Loaded LCM LoRA with alpha: {args.lora_alpha}") + # Create LoRA config, FIXME: now hard-coded. + lora_target_modules = [ + "to_q", + "to_kv", + "0.to_out", + "attn1.to_k", + "attn1.to_v", + "to_k_ip", + "to_v_ip", + "ln_k_ip.linear", + "ln_v_ip.linear", + "to_out.0", + "proj_in", + "proj_out", + "ff.net.0.proj", + "ff.net.2", + "conv1", + "conv2", + "conv_shortcut", + "downsamplers.0.conv", + "upsamplers.0.conv", + "time_emb_proj", + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + ) + + unet.add_adapter(lora_config) + if args.pretrained_lcm_lora_path is not None: + incompatible_keys = set_peft_model_state_dict(unet, lora_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + missing_keys = getattr(incompatible_keys, "missing_keys", None) + if unexpected_keys: + raise ValueError( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + for k in missing_keys: + if "lora" in k: + raise ValueError( + f"Loading adapter weights from state_dict led to missing keys: {missing_keys}. " + ) + lcm_scheduler = LCMSingleStepScheduler.from_config(noise_scheduler.config) + + # Initialize training state. + vae.requires_grad_(False) + image_encoder.requires_grad_(False) + text_encoder.requires_grad_(False) + text_encoder_2.requires_grad_(False) + unet.requires_grad_(False) + aggregator.requires_grad_(False) + + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # `accelerate` 0.16.0 will have better support for customized saving + if args.save_only_adapter: + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + for model in models: + if isinstance(model, Aggregator): + torch.save(model.state_dict(), os.path.join(output_dir, "aggregator_ckpt.pt")) + weights.pop() + + def load_model_hook(models, input_dir): + + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + if isinstance(model, Aggregator): + aggregator_state_dict = torch.load(os.path.join(input_dir, "aggregator_ckpt.pt"), map_location="cpu") + model.load_state_dict(aggregator_state_dict) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warning( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co./docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + aggregator.enable_gradient_checkpointing() + unet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if unwrap_model(aggregator).dtype != torch.float32: + raise ValueError( + f"aggregator loaded as datatype {unwrap_model(aggregator).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + ip_params, non_ip_params = seperate_ip_params_from_unet(unet) + if args.optimize_adapter: + unet_params = ip_params + unet_frozen_params = non_ip_params + else: # freeze all unet params + unet_params = [] + unet_frozen_params = ip_params + non_ip_params + assert len(unet_frozen_params) == len(list(unet.parameters())) + params_to_optimize = [p for p in aggregator.parameters()] + params_to_optimize += unet_params + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Instantiate Loss. + losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r")) + diffusion_losses = list() + lcm_losses = list() + for loss_config in losses_configs.diffusion_losses: + logger.info(f"Using diffusion loss: {loss_config.name}") + loss = namedtuple("loss", ["loss", "weight"]) + diffusion_losses.append( + loss(loss=eval(loss_config.name)( + visualize_every_k=loss_config.visualize_every_k, + dtype=weight_dtype, + accelerator=accelerator, + **loss_config.init_params), weight=loss_config.weight) + ) + for loss_config in losses_configs.lcm_losses: + logger.info(f"Using lcm loss: {loss_config.name}") + loss = namedtuple("loss", ["loss", "weight"]) + loss_class = eval(loss_config.name) + lcm_losses.append(loss(loss=loss_class(visualize_every_k=loss_config.visualize_every_k, + dtype=weight_dtype, + accelerator=accelerator, + dino_model=image_encoder, + dino_preprocess=image_processor, + **loss_config.init_params), weight=loss_config.weight)) + + # SDXL additional condition that will be added to time embedding. + def compute_time_ids(original_size, crops_coords_top_left): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + target_size = (args.resolution, args.resolution) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + # Text prompt embeddings. + @torch.no_grad() + def compute_embeddings(batch, text_encoders, tokenizers, proportion_empty_prompts=0.0, drop_idx=None, is_train=True): + prompt_batch = batch[args.caption_column] + if drop_idx is not None: + for i in range(len(prompt_batch)): + prompt_batch[i] = "" if drop_idx[i] else prompt_batch[i] + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, is_train + ) + + add_time_ids = torch.cat( + [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])] + ) + + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + unet_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids} + + return prompt_embeds, unet_added_cond_kwargs + + @torch.no_grad() + def get_added_cond(batch, prompt_embeds, pooled_prompt_embeds, proportion_empty_prompts=0.0, drop_idx=None, is_train=True): + + add_time_ids = torch.cat( + [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])] + ) + + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + unet_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids} + + return prompt_embeds, unet_added_cond_kwargs + + # Move pixels into latents. + @torch.no_grad() + def convert_to_latent(pixels): + model_input = vae.encode(pixels).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + return model_input + + # Helper functions for training loop. + # if args.degradation_config_path is not None: + # with open(args.degradation_config_path) as stream: + # degradation_configs = yaml.safe_load(stream) + # deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution) + # else: + deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution) + compute_embeddings_fn = functools.partial( + compute_embeddings, + text_encoders=[text_encoder, text_encoder_2], + tokenizers=[tokenizer, tokenizer_2], + is_train=True, + ) + + datasets = [] + datasets_name = [] + datasets_weights = [] + if args.data_config_path is not None: + data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r")) + for single_dataset in data_config.datasets: + datasets_weights.append(single_dataset.dataset_weight) + datasets_name.append(single_dataset.dataset_folder) + dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder) + image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator) + image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline) + datasets.append(image_dataset) + # TODO: Validation dataset + if data_config.val_dataset is not None: + val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator) + logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}") + + # Mix training datasets. + sampler_train = None + if len(datasets) == 1: + train_dataset = datasets[0] + else: + # Weighted each dataset + train_dataset = torch.utils.data.ConcatDataset(datasets) + dataset_weights = [] + for single_dataset, single_weight in zip(datasets, datasets_weights): + dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset)) + sampler_train = torch.utils.data.WeightedRandomSampler( + weights=dataset_weights, + num_samples=len(dataset_weights) + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + sampler=sampler_train, + shuffle=True if sampler_train is None else False, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers + ) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + aggregator, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + aggregator, unet, optimizer, train_dataloader, lr_scheduler + ) + + text_encoder.to(accelerator.device, dtype=weight_dtype) + text_encoder_2.to(accelerator.device, dtype=weight_dtype) + + # # cache empty prompts and move text encoders to cpu + # empty_prompt_embeds, empty_pooled_prompt_embeds = encode_prompt( + # [""]*args.train_batch_size, [text_encoder, text_encoder_2], [tokenizer, tokenizer_2], True + # ) + # compute_embeddings_fn = functools.partial( + # get_added_cond, + # prompt_embeds=empty_prompt_embeds, + # pooled_prompt_embeds=empty_pooled_prompt_embeds, + # is_train=True, + # ) + # text_encoder.to("cpu") + # text_encoder_2.to("cpu") + + # Move vae, unet and text_encoder to device and cast to `weight_dtype`. + if args.pretrained_vae_model_name_or_path is None: + # The VAE is fp32 to avoid NaN losses. + vae.to(accelerator.device, dtype=torch.float32) + else: + vae.to(accelerator.device, dtype=weight_dtype) + image_encoder.to(accelerator.device, dtype=weight_dtype) + if args.use_ema_adapter: + # FIXME: prepare ema model + # ema_encoder_hid_proj, ema_attn_procs_list = accelerator.prepare(ema_encoder_hid_proj, ema_attn_procs_list) + ema_encoder_hid_proj.to(accelerator.device) + ema_attn_procs_list.to(accelerator.device) + for param in unet_frozen_params: + param.data = param.data.to(dtype=weight_dtype) + for param in unet_params: + param.requires_grad_(True) + unet.to(accelerator.device) + aggregator.requires_grad_(True) + aggregator.to(accelerator.device) + importance_ratio = importance_ratio.to(accelerator.device) + + # Final sanity check. + for n, p in unet.named_parameters(): + assert not p.requires_grad, n + if p.requires_grad: + assert p.dtype == torch.float32, n + else: + assert p.dtype == weight_dtype, n + for n, p in aggregator.named_parameters(): + if p.requires_grad: assert p.dtype == torch.float32, n + else: + raise RuntimeError(f"All parameters in aggregator should require grad. {n}") + assert p.dtype == weight_dtype, n + unwrap_model(unet).disable_adapters() + + if args.sanity_check: + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + + if args.use_ema_adapter: + unwrap_model(unet).set_attn_processor(ema_attn_procs) + unwrap_model(unet).encoder_hid_proj = ema_encoder_hid_proj + batch = next(iter(train_dataloader)) + lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"])) + log_validation( + unwrap_model(unet), unwrap_model(aggregator), vae, + text_encoder, text_encoder_2, tokenizer, tokenizer_2, + noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline, + args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, log_local=True + ) + if args.use_ema_adapter: + unwrap_model(unet).set_attn_processor(orig_attn_procs) + unwrap_model(unet).encoder_hid_proj = orig_encoder_hid_proj + for n, p in unet.named_parameters(): + if p.requires_grad: assert p.dtype == torch.float32, n + else: assert p.dtype == weight_dtype, n + for n, p in aggregator.named_parameters(): + if p.requires_grad: assert p.dtype == torch.float32, n + else: assert p.dtype == weight_dtype, n + print("PASS") + exit() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Optimization steps per epoch = {num_update_steps_per_epoch}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + trainable_models = [aggregator, unet] + + if args.gradient_checkpointing: + # TODO: add vae + checkpoint_models = [] + else: + checkpoint_models = [] + + image_logs = None + tic = time.time() + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + toc = time.time() + io_time = toc - tic + tic = time.time() + for model in trainable_models + checkpoint_models: + model.train() + with accelerator.accumulate(*trainable_models): + loss = torch.tensor(0.0) + + # Drop conditions. + rand_tensor = torch.rand(batch["images"].shape[0]) + drop_image_idx = rand_tensor < args.image_drop_rate + drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate) + drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate) + drop_image_idx = drop_image_idx | drop_both_idx + drop_text_idx = drop_text_idx | drop_both_idx + + # Get LQ embeddings + with torch.no_grad(): + lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"])) + lq_pt = image_processor( + images=lq_img*0.5+0.5, + do_rescale=False, return_tensors="pt" + ).pixel_values + + # Move inputs to latent space. + gt_img = gt_img.to(dtype=vae.dtype) + lq_img = lq_img.to(dtype=vae.dtype) + model_input = convert_to_latent(gt_img) + lq_latent = convert_to_latent(lq_img) + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + lq_latent = lq_latent.to(weight_dtype) + + # Process conditions. + image_embeds = prepare_training_image_embeds( + image_encoder, image_processor, + ip_adapter_image=lq_pt, ip_adapter_image_embeds=None, + device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature, + idx_to_replace=drop_image_idx + ) + prompt_embeds_input, added_conditions = compute_embeddings_fn(batch, drop_idx=drop_text_idx) + + # Sample noise that we'll add to the latents. + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device) + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + loss_weights = extract_into_tensor(importance_ratio, timesteps, noise.shape) if args.importance_sampling else None + + if args.CFG_scale > 1.0: + # Process negative conditions. + drop_idx = torch.ones_like(drop_image_idx) + neg_image_embeds = prepare_training_image_embeds( + image_encoder, image_processor, + ip_adapter_image=lq_pt, ip_adapter_image_embeds=None, + device=accelerator.device, drop_rate=1.0, output_hidden_state=args.image_encoder_hidden_feature, + idx_to_replace=drop_idx + ) + neg_prompt_embeds_input, neg_added_conditions = compute_embeddings_fn(batch, drop_idx=drop_idx) + previewer_model_input = torch.cat([noisy_model_input] * 2) + previewer_timesteps = torch.cat([timesteps] * 2) + previewer_prompt_embeds = torch.cat([neg_prompt_embeds_input, prompt_embeds_input], dim=0) + previewer_added_conditions = {} + for k in added_conditions.keys(): + previewer_added_conditions[k] = torch.cat([neg_added_conditions[k], added_conditions[k]], dim=0) + previewer_image_embeds = [] + for neg_image_embed, image_embed in zip(neg_image_embeds, image_embeds): + previewer_image_embeds.append(torch.cat([neg_image_embed, image_embed], dim=0)) + previewer_added_conditions["image_embeds"] = previewer_image_embeds + else: + previewer_model_input = noisy_model_input + previewer_timesteps = timesteps + previewer_prompt_embeds = prompt_embeds_input + previewer_added_conditions = {} + for k in added_conditions.keys(): + previewer_added_conditions[k] = added_conditions[k] + previewer_added_conditions["image_embeds"] = image_embeds + + # Get LCM preview latent + if args.use_ema_adapter: + orig_encoder_hid_proj = unet.encoder_hid_proj + orig_attn_procs = unet.attn_processors + _ema_attn_procs = copy_dict(ema_attn_procs) + unwrap_model(unet).set_attn_processor(_ema_attn_procs) + unwrap_model(unet).encoder_hid_proj = ema_encoder_hid_proj + unwrap_model(unet).enable_adapters() + preview_noise = unet( + previewer_model_input, previewer_timesteps, + encoder_hidden_states=previewer_prompt_embeds, + added_cond_kwargs=previewer_added_conditions, + return_dict=False + )[0] + if args.CFG_scale > 1.0: + preview_noise_uncond, preview_noise_cond = preview_noise.chunk(2) + cfg_scale = 1.0 + torch.rand_like(timesteps, dtype=weight_dtype) * (args.CFG_scale-1.0) + cfg_scale = cfg_scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + preview_noise = preview_noise_uncond + cfg_scale * (preview_noise_cond - preview_noise_uncond) + preview_latents = lcm_scheduler.step( + preview_noise, + timesteps, + noisy_model_input, + return_dict=False + )[0] + unwrap_model(unet).disable_adapters() + if args.use_ema_adapter: + unwrap_model(unet).set_attn_processor(orig_attn_procs) + unwrap_model(unet).encoder_hid_proj = orig_encoder_hid_proj + preview_error_latent = F.mse_loss(preview_latents, model_input).cpu().item() + preview_error_noise = F.mse_loss(preview_noise, noise).cpu().item() + + # # Add fresh noise + # if args.noisy_encoder_input: + # aggregator_noise = torch.randn_like(preview_latents) + # aggregator_input = noise_scheduler.add_noise(preview_latents, aggregator_noise, timesteps) + + down_block_res_samples, mid_block_res_sample = aggregator( + lq_latent, + timesteps, + encoder_hidden_states=prompt_embeds_input, + added_cond_kwargs=added_conditions, + controlnet_cond=preview_latents, + conditioning_scale=1.0, + return_dict=False, + ) + + # UNet denoise. + added_conditions["image_embeds"] = image_embeds + model_pred = unet( + noisy_model_input, + timesteps, + encoder_hidden_states=prompt_embeds_input, + added_cond_kwargs=added_conditions, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + return_dict=False + )[0] + + diffusion_loss_arguments = { + "target": noise, + "predict": model_pred, + "prompt_embeddings_input": prompt_embeds_input, + "timesteps": timesteps, + "weights": loss_weights, + } + + loss_dict = dict() + for loss_config in diffusion_losses: + non_weighted_loss = loss_config.loss(**diffusion_loss_arguments, accelerator=accelerator) + loss = loss + non_weighted_loss * loss_config.weight + loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item() + + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + toc = time.time() + forward_time = toc - tic + tic = toc + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if global_step % args.ema_update_steps == 0: + if args.use_ema_adapter: + update_ema_model(ema_encoder_hid_proj, orig_encoder_hid_proj, ema_beta) + update_ema_model(ema_attn_procs_list, orig_attn_procs_list, ema_beta) + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + image_logs = log_validation( + unwrap_model(unet), unwrap_model(aggregator), vae, + text_encoder, text_encoder_2, tokenizer, tokenizer_2, + noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline, + args, accelerator, weight_dtype, global_step, lq_img.detach().clone(), gt_img.detach().clone() + ) + + logs = {} + logs.update(loss_dict) + logs.update( + {"preview_error_latent": preview_error_latent, "preview_error_noise": preview_error_noise, + "lr": lr_scheduler.get_last_lr()[0], + "forward_time": forward_time, "io_time": io_time} + ) + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + tic = time.time() + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + accelerator.save_state(save_path, safe_serialization=False) + # Run a final round of validation. + # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`. + image_logs = None + if args.validation_image is not None: + image_logs = log_validation( + unwrap_model(unet), unwrap_model(aggregator), vae, + text_encoder, text_encoder_2, tokenizer, tokenizer_2, + noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline, + args, accelerator, weight_dtype, global_step, + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) \ No newline at end of file diff --git a/train_stage2_aggregator.sh b/train_stage2_aggregator.sh new file mode 100644 index 0000000000000000000000000000000000000000..768821eec2afb564be1094c10eb63632e5d741d1 --- /dev/null +++ b/train_stage2_aggregator.sh @@ -0,0 +1,24 @@ +# Stage 2: train aggregator +accelerate launch --num_processes train_stage2_aggregator.py \ + --output_dir \ + --train_data_dir \ + --logging_dir \ + --pretrained_model_name_or_path \ + --feature_extractor_path \ + --pretrained_adapter_model_path \ + --pretrained_lcm_lora_path \ + --losses_config_path config_files/losses.yaml \ + --data_config_path config_files/IR_dataset.yaml \ + --image_drop_rate 0.0 \ + --text_drop_rate 0.85 \ + --cond_drop_rate 0.15 \ + --save_only_adapter \ + --gradient_checkpointing \ + --mixed_precision fp16 \ + --train_batch_size 6 \ + --gradient_accumulation_steps 2 \ + --learning_rate 1e-4 \ + --lr_warmup_steps 1000 \ + --lr_scheduler cosine \ + --lr_num_cycles 1 \ + --resume_from_checkpoint latest \ No newline at end of file diff --git a/utils/degradation_pipeline.py b/utils/degradation_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..4cd1ca9b6b8df12003d7a14862f8f795bd5c83c0 --- /dev/null +++ b/utils/degradation_pipeline.py @@ -0,0 +1,353 @@ +import cv2 +import math +import numpy as np +import random +import torch +from torch.utils import data as data + +from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels +from basicsr.data.transforms import augment +from basicsr.utils import img2tensor, DiffJPEG, USMSharp +from basicsr.utils.img_process_util import filter2D +from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt +from basicsr.data.transforms import paired_random_crop + +AUGMENT_OPT = { + 'use_hflip': False, + 'use_rot': False +} + +KERNEL_OPT = { + 'blur_kernel_size': 21, + 'kernel_list': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'], + 'kernel_prob': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03], + 'sinc_prob': 0.1, + 'blur_sigma': [0.2, 3], + 'betag_range': [0.5, 4], + 'betap_range': [1, 2], + + 'blur_kernel_size2': 21, + 'kernel_list2': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'], + 'kernel_prob2': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03], + 'sinc_prob2': 0.1, + 'blur_sigma2': [0.2, 1.5], + 'betag_range2': [0.5, 4], + 'betap_range2': [1, 2], + 'final_sinc_prob': 0.8, +} + +DEGRADE_OPT = { + 'resize_prob': [0.2, 0.7, 0.1], # up, down, keep + 'resize_range': [0.15, 1.5], + 'gaussian_noise_prob': 0.5, + 'noise_range': [1, 30], + 'poisson_scale_range': [0.05, 3], + 'gray_noise_prob': 0.4, + 'jpeg_range': [30, 95], + + # the second degradation process + 'second_blur_prob': 0.8, + 'resize_prob2': [0.3, 0.4, 0.3], # up, down, keep + 'resize_range2': [0.3, 1.2], + 'gaussian_noise_prob2': 0.5, + 'noise_range2': [1, 25], + 'poisson_scale_range2': [0.05, 2.5], + 'gray_noise_prob2': 0.4, + 'jpeg_range2': [30, 95], + + 'gt_size': 512, + 'no_degradation_prob': 0.01, + 'use_usm': True, + 'sf': 4, + 'random_size': False, + 'resize_lq': True +} + +class RealESRGANDegradation: + + def __init__(self, augment_opt=None, kernel_opt=None, degrade_opt=None, device='cuda', resolution=None): + if augment_opt is None: + augment_opt = AUGMENT_OPT + self.augment_opt = augment_opt + if kernel_opt is None: + kernel_opt = KERNEL_OPT + self.kernel_opt = kernel_opt + if degrade_opt is None: + degrade_opt = DEGRADE_OPT + self.degrade_opt = degrade_opt + if resolution is not None: + self.degrade_opt['gt_size'] = resolution + self.device = device + + self.jpeger = DiffJPEG(differentiable=False).to(self.device) + self.usm_sharpener = USMSharp().to(self.device) + + # blur settings for the first degradation + self.blur_kernel_size = kernel_opt['blur_kernel_size'] + self.kernel_list = kernel_opt['kernel_list'] + self.kernel_prob = kernel_opt['kernel_prob'] # a list for each kernel probability + self.blur_sigma = kernel_opt['blur_sigma'] + self.betag_range = kernel_opt['betag_range'] # betag used in generalized Gaussian blur kernels + self.betap_range = kernel_opt['betap_range'] # betap used in plateau blur kernels + self.sinc_prob = kernel_opt['sinc_prob'] # the probability for sinc filters + + # blur settings for the second degradation + self.blur_kernel_size2 = kernel_opt['blur_kernel_size2'] + self.kernel_list2 = kernel_opt['kernel_list2'] + self.kernel_prob2 = kernel_opt['kernel_prob2'] + self.blur_sigma2 = kernel_opt['blur_sigma2'] + self.betag_range2 = kernel_opt['betag_range2'] + self.betap_range2 = kernel_opt['betap_range2'] + self.sinc_prob2 = kernel_opt['sinc_prob2'] + + # a final sinc filter + self.final_sinc_prob = kernel_opt['final_sinc_prob'] + + self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 + # TODO: kernel range is now hard-coded, should be in the configure file + self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect + self.pulse_tensor[10, 10] = 1 + + def get_kernel(self): + + # ------------------------ Generate kernels (used in the first degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.kernel_opt['sinc_prob']: + # this sinc filter setting is for kernels ranging from [7, 21] + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel = random_mixed_kernels( + self.kernel_list, + self.kernel_prob, + kernel_size, + self.blur_sigma, + self.blur_sigma, [-math.pi, math.pi], + self.betag_range, + self.betap_range, + noise_range=None) + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------ Generate kernels (used in the second degradation) ------------------------ # + kernel_size = random.choice(self.kernel_range) + if np.random.uniform() < self.kernel_opt['sinc_prob2']: + if kernel_size < 13: + omega_c = np.random.uniform(np.pi / 3, np.pi) + else: + omega_c = np.random.uniform(np.pi / 5, np.pi) + kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) + else: + kernel2 = random_mixed_kernels( + self.kernel_list2, + self.kernel_prob2, + kernel_size, + self.blur_sigma2, + self.blur_sigma2, [-math.pi, math.pi], + self.betag_range2, + self.betap_range2, + noise_range=None) + + # pad kernel + pad_size = (21 - kernel_size) // 2 + kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) + + # ------------------------------------- the final sinc kernel ------------------------------------- # + if np.random.uniform() < self.kernel_opt['final_sinc_prob']: + kernel_size = random.choice(self.kernel_range) + omega_c = np.random.uniform(np.pi / 3, np.pi) + sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) + sinc_kernel = torch.FloatTensor(sinc_kernel) + else: + sinc_kernel = self.pulse_tensor + + # BGR to RGB, HWC to CHW, numpy to tensor + kernel = torch.FloatTensor(kernel) + kernel2 = torch.FloatTensor(kernel2) + + return (kernel, kernel2, sinc_kernel) + + @torch.no_grad() + def __call__(self, img_gt, kernels=None): + ''' + :param: img_gt: BCHW, RGB, [0, 1] float32 tensor + ''' + if kernels is None: + kernel = [] + kernel2 = [] + sinc_kernel = [] + for _ in range(img_gt.shape[0]): + k, k2, sk = self.get_kernel() + kernel.append(k) + kernel2.append(k2) + sinc_kernel.append(sk) + kernel = torch.stack(kernel) + kernel2 = torch.stack(kernel2) + sinc_kernel = torch.stack(sinc_kernel) + else: + # kernels created in dataset. + kernel, kernel2, sinc_kernel = kernels + + # ----------------------- Pre-process ----------------------- # + im_gt = img_gt.to(self.device) + if self.degrade_opt['use_usm']: + im_gt = self.usm_sharpener(im_gt) + im_gt = im_gt.to(memory_format=torch.contiguous_format).float() + kernel = kernel.to(self.device) + kernel2 = kernel2.to(self.device) + sinc_kernel = sinc_kernel.to(self.device) + ori_h, ori_w = im_gt.size()[2:4] + + # ----------------------- The first degradation process ----------------------- # + # blur + out = filter2D(im_gt, kernel) + # random resize + updown_type = random.choices( + ['up', 'down', 'keep'], + self.degrade_opt['resize_prob'], + )[0] + if updown_type == 'up': + scale = random.uniform(1, self.degrade_opt['resize_range'][1]) + elif updown_type == 'down': + scale = random.uniform(self.degrade_opt['resize_range'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = torch.nn.functional.interpolate(out, scale_factor=scale, mode=mode) + # add noise + gray_noise_prob = self.degrade_opt['gray_noise_prob'] + if random.random() < self.degrade_opt['gaussian_noise_prob']: + out = random_add_gaussian_noise_pt( + out, + sigma_range=self.degrade_opt['noise_range'], + clip=True, + rounds=False, + gray_prob=gray_noise_prob, + ) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.degrade_opt['poisson_scale_range'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range']) + out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts + out = self.jpeger(out, quality=jpeg_p) + + # ----------------------- The second degradation process ----------------------- # + # blur + if random.random() < self.degrade_opt['second_blur_prob']: + out = out.contiguous() + out = filter2D(out, kernel2) + # random resize + updown_type = random.choices( + ['up', 'down', 'keep'], + self.degrade_opt['resize_prob2'], + )[0] + if updown_type == 'up': + scale = random.uniform(1, self.degrade_opt['resize_range2'][1]) + elif updown_type == 'down': + scale = random.uniform(self.degrade_opt['resize_range2'][0], 1) + else: + scale = 1 + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = torch.nn.functional.interpolate( + out, + size=(int(ori_h / self.degrade_opt['sf'] * scale), + int(ori_w / self.degrade_opt['sf'] * scale)), + mode=mode, + ) + # add noise + gray_noise_prob = self.degrade_opt['gray_noise_prob2'] + if random.random() < self.degrade_opt['gaussian_noise_prob2']: + out = random_add_gaussian_noise_pt( + out, + sigma_range=self.degrade_opt['noise_range2'], + clip=True, + rounds=False, + gray_prob=gray_noise_prob, + ) + else: + out = random_add_poisson_noise_pt( + out, + scale_range=self.degrade_opt['poisson_scale_range2'], + gray_prob=gray_noise_prob, + clip=True, + rounds=False, + ) + + # JPEG compression + the final sinc filter + # We also need to resize images to desired sizes. We group [resize back + sinc filter] together + # as one operation. + # We consider two orders: + # 1. [resize back + sinc filter] + JPEG compression + # 2. JPEG compression + [resize back + sinc filter] + # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. + if random.random() < 0.5: + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = torch.nn.functional.interpolate( + out, + size=(ori_h // self.degrade_opt['sf'], + ori_w // self.degrade_opt['sf']), + mode=mode, + ) + out = out.contiguous() + out = filter2D(out, sinc_kernel) + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + else: + # JPEG compression + jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range2']) + out = torch.clamp(out, 0, 1) + out = self.jpeger(out, quality=jpeg_p) + # resize back + the final sinc filter + mode = random.choice(['area', 'bilinear', 'bicubic']) + out = torch.nn.functional.interpolate( + out, + size=(ori_h // self.degrade_opt['sf'], + ori_w // self.degrade_opt['sf']), + mode=mode, + ) + out = out.contiguous() + out = filter2D(out, sinc_kernel) + + # clamp and round + im_lq = torch.clamp(out, 0, 1.0) + + # random crop + gt_size = self.degrade_opt['gt_size'] + im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, self.degrade_opt['sf']) + + if self.degrade_opt['resize_lq']: + im_lq = torch.nn.functional.interpolate( + im_lq, + size=(im_gt.size(-2), + im_gt.size(-1)), + mode='bicubic', + ) + + if random.random() < self.degrade_opt['no_degradation_prob'] or torch.isnan(im_lq).any(): + im_lq = im_gt + + # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue + im_lq = im_lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract + im_lq = im_lq*2 - 1.0 + im_gt = im_gt*2 - 1.0 + + if self.degrade_opt['random_size']: + raise NotImplementedError + im_lq, im_gt = self.randn_cropinput(im_lq, im_gt) + + im_lq = torch.clamp(im_lq, -1.0, 1.0) + im_gt = torch.clamp(im_gt, -1.0, 1.0) + + return (im_lq, im_gt) \ No newline at end of file diff --git a/utils/matlab_cp2tform.py b/utils/matlab_cp2tform.py new file mode 100644 index 0000000000000000000000000000000000000000..cdcdf96ab45577bcea54c809b4d241c9e8a6a74e --- /dev/null +++ b/utils/matlab_cp2tform.py @@ -0,0 +1,350 @@ +# -*- coding: utf-8 -*- +""" +Created on Tue Jul 11 06:54:28 2017 + +@author: zhaoyafei +""" + +import numpy as np +from numpy.linalg import inv, norm, lstsq +from numpy.linalg import matrix_rank as rank + +class MatlabCp2tormException(Exception): + def __str__(self): + return 'In File {}:{}'.format( + __file__, super.__str__(self)) + +def tformfwd(trans, uv): + """ + Function: + ---------- + apply affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of transformed coordinates (x, y) + """ + uv = np.hstack(( + uv, np.ones((uv.shape[0], 1)) + )) + xy = np.dot(uv, trans) + xy = xy[:, 0:-1] + return xy + + +def tforminv(trans, uv): + """ + Function: + ---------- + apply the inverse of affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of inverse-transformed coordinates (x, y) + """ + Tinv = inv(trans) + xy = tformfwd(Tinv, uv) + return xy + + +def findNonreflectiveSimilarity(uv, xy, options=None): + + options = {'K': 2} + + K = options['K'] + M = xy.shape[0] + x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + # print('--->x, y:\n', x, y + + tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) + tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) + X = np.vstack((tmp1, tmp2)) + # print('--->X.shape: ', X.shape + # print('X:\n', X + + u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + U = np.vstack((u, v)) + # print('--->U.shape: ', U.shape + # print('U:\n', U + + # We know that X * r = U + if rank(X) >= 2 * K: + r, _, _, _ = lstsq(X, U) + r = np.squeeze(r) + else: + raise Exception('cp2tform:twoUniquePointsReq') + + # print('--->r:\n', r + + sc = r[0] + ss = r[1] + tx = r[2] + ty = r[3] + + Tinv = np.array([ + [sc, -ss, 0], + [ss, sc, 0], + [tx, ty, 1] + ]) + + # print('--->Tinv:\n', Tinv + + T = inv(Tinv) + # print('--->T:\n', T + + T[:, 2] = np.array([0, 0, 1]) + + return T, Tinv + + +def findSimilarity(uv, xy, options=None): + + options = {'K': 2} + +# uv = np.array(uv) +# xy = np.array(xy) + + # Solve for trans1 + trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) + + # Solve for trans2 + + # manually reflect the xy data across the Y-axis + xyR = xy + xyR[:, 0] = -1 * xyR[:, 0] + + trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) + + # manually reflect the tform to undo the reflection done on xyR + TreflectY = np.array([ + [-1, 0, 0], + [0, 1, 0], + [0, 0, 1] + ]) + + trans2 = np.dot(trans2r, TreflectY) + + # Figure out if trans1 or trans2 is better + xy1 = tformfwd(trans1, uv) + norm1 = norm(xy1 - xy) + + xy2 = tformfwd(trans2, uv) + norm2 = norm(xy2 - xy) + + if norm1 <= norm2: + return trans1, trans1_inv + else: + trans2_inv = inv(trans2) + return trans2, trans2_inv + + +def get_similarity_transform(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'trans': + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y, 1] = [u, v, 1] * trans + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + @reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + trans_inv: 3x3 np.array + inverse of trans, transform matrix from xy to uv + """ + + if reflective: + trans, trans_inv = findSimilarity(src_pts, dst_pts) + else: + trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) + + return trans, trans_inv + + +def cvt_tform_mat_for_cv2(trans): + """ + Function: + ---------- + Convert Transform Matrix 'trans' into 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + cv2_trans = trans[:, 0:2].T + + return cv2_trans + + +def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) + cv2_trans = cvt_tform_mat_for_cv2(trans) + + return cv2_trans + + +if __name__ == '__main__': + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + # In Matlab, run: + # + # uv = [u'; v']; + # xy = [x'; y']; + # tform_sim=cp2tform(uv,xy,'similarity'); + # + # trans = tform_sim.tdata.T + # ans = + # -0.0764 -1.6190 0 + # 1.6190 -0.0764 0 + # -3.2156 0.0290 1.0000 + # trans_inv = tform_sim.tdata.Tinv + # ans = + # + # -0.0291 0.6163 0 + # -0.6163 -0.0291 0 + # -0.0756 1.9826 1.0000 + # xy_m=tformfwd(tform_sim, u,v) + # + # xy_m = + # + # -3.2156 0.0290 + # 1.1833 -9.9143 + # 5.0323 2.8853 + # uv_m=tforminv(tform_sim, x,y) + # + # uv_m = + # + # 0.5698 1.3953 + # 6.0872 2.2733 + # -2.6570 4.3314 + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + uv = np.array((u, v)).T + xy = np.array((x, y)).T + + print('\n--->uv:') + print(uv) + print('\n--->xy:') + print(xy) + + trans, trans_inv = get_similarity_transform(uv, xy) + + print('\n--->trans matrix:') + print(trans) + + print('\n--->trans_inv matrix:') + print(trans_inv) + + print('\n---> apply transform to uv') + print('\nxy_m = uv_augmented * trans') + uv_aug = np.hstack(( + uv, np.ones((uv.shape[0], 1)) + )) + xy_m = np.dot(uv_aug, trans) + print(xy_m) + + print('\nxy_m = tformfwd(trans, uv)') + xy_m = tformfwd(trans, uv) + print(xy_m) + + print('\n---> apply inverse transform to xy') + print('\nuv_m = xy_augmented * trans_inv') + xy_aug = np.hstack(( + xy, np.ones((xy.shape[0], 1)) + )) + uv_m = np.dot(xy_aug, trans_inv) + print(uv_m) + + print('\nuv_m = tformfwd(trans_inv, xy)') + uv_m = tformfwd(trans_inv, xy) + print(uv_m) + + uv_m = tforminv(trans, xy) + print('\nuv_m = tforminv(trans, xy)') + print(uv_m) diff --git a/utils/parser.py b/utils/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..3830be1212234577d6cd86a293b1285269008537 --- /dev/null +++ b/utils/parser.py @@ -0,0 +1,452 @@ +import argparse +import os + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Train Consistency Encoder.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + + # parser.add_argument( + # "--instance_data_dir", + # type=str, + # required=True, + # help=("A folder containing the training data. "), + # ) + + parser.add_argument( + "--data_config_path", + type=str, + required=True, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_train_vis_images", + type=int, + default=2, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=2, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + + parser.add_argument( + "--validation_vis_steps", + type=int, + default=500, + help=( + "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + + parser.add_argument( + "--train_vis_steps", + type=int, + default=500, + help=( + "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + + parser.add_argument( + "--vis_lcm", + type=bool, + default=True, + help=( + "Also log results of LCM inference", + ), + ) + + parser.add_argument( + "--output_dir", + type=str, + default="lora-dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + + parser.add_argument("--save_only_encoder", action="store_true", help="Only save the encoder and not the full accelerator state") + + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + + parser.add_argument("--freeze_encoder_unet", action="store_true", help="Don't train encoder unet") + parser.add_argument("--predict_word_embedding", action="store_true", help="Predict word embeddings in addition to KV features") + parser.add_argument("--ip_adapter_feature_extractor_path", type=str, help="Path to pre-trained feature extractor for IP-adapter") + parser.add_argument("--ip_adapter_model_path", type=str, help="Path to pre-trained IP-adapter.") + parser.add_argument("--ip_adapter_tokens", type=int, default=16, help="Number of tokens to use in IP-adapter cross attention mechanism") + parser.add_argument("--optimize_adapter", action="store_true", help="Optimize IP-adapter parameters (projector + cross-attention layers)") + parser.add_argument("--adapter_attention_scale", type=float, default=1.0, help="Relative strength of the adapter cross attention layers") + parser.add_argument("--adapter_lr", type=float, help="Learning rate for the adapter parameters. Defaults to the global LR if not provided") + + parser.add_argument("--noisy_encoder_input", action="store_true", help="Noise the encoder input to the same step as the decoder?") + + # related to CFG: + parser.add_argument("--adapter_drop_chance", type=float, default=0.0, help="Chance to drop adapter condition input during training") + parser.add_argument("--text_drop_chance", type=float, default=0.0, help="Chance to drop text condition during training") + parser.add_argument("--kv_drop_chance", type=float, default=0.0, help="Chance to drop KV condition during training") + + + + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + + parser.add_argument("--num_train_epochs", type=int, default=1) + + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=5, + help=("Max number of checkpoints to store."), + ) + + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + + parser.add_argument("--max_timesteps_for_x0_loss", type=int, default=1001) + + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + + parser.add_argument( + "--report_to", + type=str, + default="wandb", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + + parser.add_argument( + "--pretrained_lcm_lora_path", + type=str, + default="latent-consistency/lcm-lora-sdxl", + help=("Path for lcm lora pretrained"), + ) + + parser.add_argument( + "--losses_config_path", + type=str, + required=True, + help=("A yaml file containing losses to use and their weights."), + ) + + parser.add_argument( + "--lcm_every_k_steps", + type=int, + default=-1, + help="How often to run lcm. If -1, lcm is not run." + ) + + parser.add_argument( + "--lcm_batch_size", + type=int, + default=1, + help="Batch size for lcm." + ) + parser.add_argument( + "--lcm_max_timestep", + type=int, + default=1000, + help="Max timestep to use with LCM." + ) + + parser.add_argument( + "--lcm_sample_scale_every_k_steps", + type=int, + default=-1, + help="How often to change lcm scale. If -1, scale is fixed at 1." + ) + + parser.add_argument( + "--lcm_min_scale", + type=float, + default=0.1, + help="When sampling lcm scale, the minimum scale to use." + ) + + parser.add_argument( + "--scale_lcm_by_max_step", + action="store_true", + help="scale LCM lora alpha linearly by the maximal timestep sampled that iteration" + ) + + parser.add_argument( + "--lcm_sample_full_lcm_prob", + type=float, + default=0.2, + help="When sampling lcm scale, the probability of using full lcm (scale of 1)." + ) + + parser.add_argument( + "--run_on_cpu", + action="store_true", + help="whether to run on cpu or not" + ) + + parser.add_argument( + "--experiment_name", + type=str, + help=("A short description of the experiment to add to the wand run log. "), + ) + parser.add_argument("--encoder_lora_rank", type=int, default=0, help="Rank of Lora in unet encoder. 0 means no lora") + + parser.add_argument("--kvcopy_lora_rank", type=int, default=0, help="Rank of lora in the kvcopy modules. 0 means no lora") + + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + args.optimizer = "AdamW" + + return args \ No newline at end of file diff --git a/utils/text_utils.py b/utils/text_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1326275f969ab5117f31b508c2e30b1438f893 --- /dev/null +++ b/utils/text_utils.py @@ -0,0 +1,76 @@ +import torch + +def tokenize_prompt(tokenizer, prompt): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): + prompt_embeds_list = [] + + for i, text_encoder in enumerate(text_encoders): + if tokenizers is not None: + tokenizer = tokenizers[i] + text_input_ids = tokenize_prompt(tokenizer, prompt) + else: + assert text_input_ids_list is not None + text_input_ids = text_input_ids_list[i] + + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def add_tokens(tokenizers, tokens, text_encoders): + new_token_indices = {} + for idx, tokenizer in enumerate(tokenizers): + for token in tokens: + num_added_tokens = tokenizer.add_tokens(token) + if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + + new_token_indices[f"{idx}_{token}"] = num_added_tokens + # resize embedding layers to avoid crash. We will never actually use these. + text_encoders[idx].resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128) + + return new_token_indices + + +def patch_embedding_forward(embedding_layer, new_tokens, new_embeddings): + + def new_forward(input): + embedded_text = torch.nn.functional.embedding( + input, embedding_layer.weight, embedding_layer.padding_idx, embedding_layer.max_norm, + embedding_layer.norm_type, embedding_layer.scale_grad_by_freq, embedding_layer.sparse) + + replace_indices = (input == new_tokens) + + if torch.count_nonzero(replace_indices) > 0: + embedded_text[replace_indices] = new_embeddings + + return embedded_text + + embedding_layer.forward = new_forward \ No newline at end of file diff --git a/utils/train_utils.py b/utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..02d6cd734dbd1cbe64b75e19f61bb5ddd782ab76 --- /dev/null +++ b/utils/train_utils.py @@ -0,0 +1,360 @@ +import argparse +import contextlib +import time +import gc +import logging +import math +import os +import random +import jsonlines +import functools +import shutil +import pyrallis +import itertools +from pathlib import Path +from collections import namedtuple, OrderedDict + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from datasets import load_dataset +from packaging import version +from PIL import Image +from losses.losses import * +from torchvision import transforms +from torchvision.transforms.functional import crop +from tqdm.auto import tqdm + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + from transformers import PretrainedConfig + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + +def get_train_dataset(dataset_name, dataset_dir, args, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + dataset = load_dataset( + dataset_name, + data_dir=dataset_dir, + cache_dir=os.path.join(dataset_dir, ".cache"), + num_proc=4, + split="train", + ) + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset.column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + args.image_column = column_names[0] + logger.info(f"image column defaulting to {column_names[0]}") + else: + image_column = args.image_column + if image_column not in column_names: + logger.warning(f"dataset {dataset_name} has no column {image_column}") + + if args.caption_column is None: + args.caption_column = column_names[1] + logger.info(f"caption column defaulting to {column_names[1]}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + logger.warning(f"dataset {dataset_name} has no column {caption_column}") + + if args.conditioning_image_column is None: + args.conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {column_names[2]}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + logger.warning(f"dataset {dataset_name} has no column {conditioning_image_column}") + + with accelerator.main_process_first(): + train_dataset = dataset.shuffle(seed=args.seed) + if args.max_train_samples is not None: + train_dataset = train_dataset.select(range(args.max_train_samples)) + return train_dataset + +def prepare_train_dataset(dataset, accelerator, deg_pipeline, centralize=False): + + # Data augmentations. + hflip = deg_pipeline.augment_opt['use_hflip'] and random.random() < 0.5 + vflip = deg_pipeline.augment_opt['use_rot'] and random.random() < 0.5 + rot90 = deg_pipeline.augment_opt['use_rot'] and random.random() < 0.5 + augment_transforms = [] + if hflip: + augment_transforms.append(transforms.RandomHorizontalFlip(p=1.0)) + if vflip: + augment_transforms.append(transforms.RandomVerticalFlip(p=1.0)) + if rot90: + # FIXME + augment_transforms.append(transforms.RandomRotation(degrees=(90,90))) + torch_transforms=[transforms.ToTensor()] + if centralize: + # to [-1, 1] + torch_transforms.append(transforms.Normalize([0.5], [0.5])) + + training_size = deg_pipeline.degrade_opt['gt_size'] + image_transforms = transforms.Compose(augment_transforms) + train_transforms = transforms.Compose(torch_transforms) + train_resize = transforms.Resize(training_size, interpolation=transforms.InterpolationMode.BILINEAR) + train_crop = transforms.RandomCrop(training_size) + + def preprocess_train(examples): + raw_images = [] + for img_data in examples[args.image_column]: + raw_images.append(Image.open(img_data).convert("RGB")) + + # Image stack. + images = [] + original_sizes = [] + crop_top_lefts = [] + # Degradation kernels stack. + kernel = [] + kernel2 = [] + sinc_kernel = [] + + for raw_image in raw_images: + raw_image = image_transforms(raw_image) + original_sizes.append((raw_image.height, raw_image.width)) + + # Resize smaller edge. + raw_image = train_resize(raw_image) + # Crop to training size. + y1, x1, h, w = train_crop.get_params(raw_image, (training_size, training_size)) + raw_image = crop(raw_image, y1, x1, h, w) + crop_top_left = (y1, x1) + crop_top_lefts.append(crop_top_left) + image = train_transforms(raw_image) + + images.append(image) + k, k2, sk = deg_pipeline.get_kernel() + kernel.append(k) + kernel2.append(k2) + sinc_kernel.append(sk) + + examples["images"] = images + examples["original_sizes"] = original_sizes + examples["crop_top_lefts"] = crop_top_lefts + examples["kernel"] = kernel + examples["kernel2"] = kernel2 + examples["sinc_kernel"] = sinc_kernel + + return examples + + with accelerator.main_process_first(): + dataset = dataset.with_transform(preprocess_train) + + return dataset + +def collate_fn(examples): + images = torch.stack([example["images"] for example in examples]) + images = images.to(memory_format=torch.contiguous_format).float() + kernel = torch.stack([example["kernel"] for example in examples]) + kernel = kernel.to(memory_format=torch.contiguous_format).float() + kernel2 = torch.stack([example["kernel2"] for example in examples]) + kernel2 = kernel2.to(memory_format=torch.contiguous_format).float() + sinc_kernel = torch.stack([example["sinc_kernel"] for example in examples]) + sinc_kernel = sinc_kernel.to(memory_format=torch.contiguous_format).float() + original_sizes = [example["original_sizes"] for example in examples] + crop_top_lefts = [example["crop_top_lefts"] for example in examples] + + prompts = [] + for example in examples: + prompts.append(example[args.caption_column]) if args.caption_column in example else prompts.append("") + + return { + "images": images, + "text": prompts, + "kernel": kernel, + "kernel2": kernel2, + "sinc_kernel": sinc_kernel, + "original_sizes": original_sizes, + "crop_top_lefts": crop_top_lefts, + } + +def encode_prompt(prompt_batch, text_encoders, tokenizers, is_train=True): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + +def importance_sampling_fn(t, max_t, alpha): + """Importance Sampling Function f(t)""" + return 1 / max_t * (1 - alpha * np.cos(np.pi * t / max_t)) + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def tensor_to_pil(images): + """ + Convert image tensor or a batch of image tensors to PIL image(s). + """ + images = (images + 1) / 2 + images_np = images.detach().cpu().numpy() + if images_np.ndim == 4: + images_np = np.transpose(images_np, (0, 2, 3, 1)) + elif images_np.ndim == 3: + images_np = np.transpose(images_np, (1, 2, 0)) + images_np = images_np[None, ...] + images_np = (images_np * 255).round().astype("uint8") + if images_np.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images_np] + else: + pil_images = [Image.fromarray(image[:, :, :3]) for image in images_np] + + return pil_images + +def save_np_to_image(img_np, save_dir): + img_np = np.transpose(img_np, (0, 2, 3, 1)) + img_np = (img_np * 255).astype(np.uint8) + img_np = Image.fromarray(img_np[0]) + img_np.save(save_dir) + + +def seperate_SFT_params_from_unet(unet): + params = [] + non_params = [] + for name, param in unet.named_parameters(): + if "SFT" in name: + params.append(param) + else: + non_params.append(param) + return params, non_params + + +def seperate_lora_params_from_unet(unet): + keys = [] + frozen_keys = [] + for name, param in unet.named_parameters(): + if "lora" in name: + keys.append(param) + else: + frozen_keys.append(param) + return keys, frozen_keys + + +def seperate_ip_params_from_unet(unet): + ip_params = [] + non_ip_params = [] + for name, param in unet.named_parameters(): + if "encoder_hid_proj." in name or "_ip." in name: + ip_params.append(param) + elif "attn" in name and "processor" in name: + if "ip" in name or "ln" in name: + ip_params.append(param) + else: + non_ip_params.append(param) + return ip_params, non_ip_params + + +def seperate_ref_params_from_unet(unet): + ip_params = [] + non_ip_params = [] + for name, param in unet.named_parameters(): + if "encoder_hid_proj." in name or "_ip." in name: + ip_params.append(param) + elif "attn" in name and "processor" in name: + if "ip" in name or "ln" in name: + ip_params.append(param) + elif "extract" in name: + ip_params.append(param) + else: + non_ip_params.append(param) + return ip_params, non_ip_params + + +def seperate_ip_modules_from_unet(unet): + ip_modules = [] + non_ip_modules = [] + for name, module in unet.named_modules(): + if "encoder_hid_proj" in name or "attn2.processor" in name: + ip_modules.append(module) + else: + non_ip_modules.append(module) + return ip_modules, non_ip_modules + + +def seperate_SFT_keys_from_unet(unet): + keys = [] + non_keys = [] + for name, param in unet.named_parameters(): + if "SFT" in name: + keys.append(name) + else: + non_keys.append(name) + return keys, non_keys + + +def seperate_ip_keys_from_unet(unet): + keys = [] + non_keys = [] + for name, param in unet.named_parameters(): + if "encoder_hid_proj." in name or "_ip." in name: + keys.append(name) + elif "attn" in name and "processor" in name: + if "ip" in name or "ln" in name: + keys.append(name) + else: + non_keys.append(name) + return keys, non_keys \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e85c7a43b46fd305c643e1c1b387e7490e5dcc5 --- /dev/null +++ b/utils/utils.py @@ -0,0 +1,51 @@ +import torch +import numpy as np +from einops import rearrange +from kornia.geometry.transform.crop2d import warp_affine + +from utils.matlab_cp2tform import get_similarity_transform_for_cv2 +from torchvision.transforms import Pad + +REFERNCE_FACIAL_POINTS_RELATIVE = np.array([[38.29459953, 51.69630051], + [72.53179932, 51.50139999], + [56.02519989, 71.73660278], + [41.54930115, 92.3655014], + [70.72990036, 92.20410156] + ]) / 112 # Original points are 112 * 96 added 8 to the x axis to make it 112 * 112 + + +@torch.no_grad() +def detect_face(images: torch.Tensor, mtcnn: torch.nn.Module) -> torch.Tensor: + """ + Detect faces in the images using MTCNN. If no face is detected, use the whole image. + """ + images = rearrange(images, "b c h w -> b h w c") + if images.dtype != torch.uint8: + images = ((images * 0.5 + 0.5) * 255).type(torch.uint8) # Unnormalize + + _, _, landmarks = mtcnn(images, landmarks=True) + + return landmarks + + +def extract_faces_and_landmarks(images: torch.Tensor, output_size=112, mtcnn: torch.nn.Module = None, refernce_points=REFERNCE_FACIAL_POINTS_RELATIVE): + """ + detect faces in the images and crop them (in a differentiable way) to 112x112 using MTCNN. + """ + images = Pad(200)(images) + landmarks_batched = detect_face(images, mtcnn=mtcnn) + affine_transformations = [] + invalid_indices = [] + for i, landmarks in enumerate(landmarks_batched): + if landmarks is None: + invalid_indices.append(i) + affine_transformations.append(np.eye(2, 3).astype(np.float32)) + else: + affine_transformations.append(get_similarity_transform_for_cv2(landmarks[0].astype(np.float32), + refernce_points.astype(np.float32) * output_size)) + affine_transformations = torch.from_numpy(np.stack(affine_transformations).astype(np.float32)).to(device=images.device, dtype=torch.float32) + + invalid_indices = torch.tensor(invalid_indices).to(device=images.device) + + fp_images = images.to(torch.float32) + return warp_affine(fp_images, affine_transformations, dsize=(output_size, output_size)).to(dtype=images.dtype), invalid_indices \ No newline at end of file diff --git a/utils/vis_utils.py b/utils/vis_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e25a788a9010d002d622104bf6b7a3e05c375749 --- /dev/null +++ b/utils/vis_utils.py @@ -0,0 +1,58 @@ +import textwrap +from typing import List, Tuple, Optional + +import numpy as np +from PIL import Image, ImageDraw, ImageFont + +LINE_WIDTH = 20 + + +def add_text_to_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0), + min_lines: Optional[int] = None, add_below: bool = True): + import textwrap + lines = textwrap.wrap(text, width=LINE_WIDTH) + if min_lines is not None and len(lines) < min_lines: + if add_below: + lines += [''] * (min_lines - len(lines)) + else: + lines = [''] * (min_lines - len(lines)) + lines + h, w, c = image.shape + offset = int(h * .12) + img = np.ones((h + offset * len(lines), w, c), dtype=np.uint8) * 255 + font_size = int(offset * .8) + + try: + font = ImageFont.truetype("assets/OpenSans-Regular.ttf", font_size) + textsize = font.getbbox(text) + y_offset = (offset - textsize[3]) // 2 + except: + font = ImageFont.load_default() + y_offset = offset // 2 + + if add_below: + img[:h] = image + else: + img[-h:] = image + img = Image.fromarray(img) + draw = ImageDraw.Draw(img) + for i, line in enumerate(lines): + line_size = font.getbbox(line) + text_x = (w - line_size[2]) // 2 + if add_below: + draw.text((text_x, h + y_offset + offset * i), line, font=font, fill=text_color) + else: + draw.text((text_x, 0 + y_offset + offset * i), line, font=font, fill=text_color) + return np.array(img) + + +def create_table_plot(titles: List[str], images: List[Image.Image], captions: List[str]) -> Image.Image: + title_max_lines = np.max([len(textwrap.wrap(text, width=LINE_WIDTH)) for text in titles]) + caption_max_lines = np.max([len(textwrap.wrap(text, width=LINE_WIDTH)) for text in captions]) + out_images = [] + for i in range(len(images)): + im = np.array(images[i]) + im = add_text_to_image(im, titles[i], add_below=False, min_lines=title_max_lines) + im = add_text_to_image(im, captions[i], add_below=True, min_lines=caption_max_lines) + out_images.append(im) + image = Image.fromarray(np.concatenate(out_images, axis=1)) + return image