diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..4fc94f3541da4fc063bd9b4574f23fbb8afebb31 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/sculpture.png filter=lfs diff=lfs merge=lfs -text
+assets/teaser_figure.png filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..f49a4e16e68b128803cc2dcea614603632b04eac
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md
new file mode 100644
index 0000000000000000000000000000000000000000..49b240877d130b0ac8144fd91b08054efb6e78b6
--- /dev/null
+++ b/ORIGINAL_README.md
@@ -0,0 +1,191 @@
+
+
+# InstantIR: Blind Image Restoration with Instant Generative Reference
+
+[**Jen-Yuan Huang**](https://jy-joy.github.io)<sup>1,2</sup>, [**Haofan Wang**](https://haofanwang.github.io/)<sup>2</sup>, [**Qixun Wang**](https://github.com/wangqixun)<sup>2</sup>, [**Xu Bai**](https://huggingface.co./baymin0220)<sup>2</sup>, Hao Ai<sup>2</sup>, Peng Xing<sup>2</sup>, [**Jen-Tse Huang**](https://penguinnnnn.github.io)<sup>3</sup>
+
+<sup>1</sup>Peking University · <sup>2</sup>InstantX Team · <sup>3</sup>The Chinese University of Hong Kong
+
+
+
+
+
+
+
+
+
+
+
+
+**InstantIR** is a novel single-image restoration model designed to resurrect your damaged images, delivering extreme-quality yet realistic details. You can further boost **InstantIR** performance with additional text prompts, and even achieve customized editing!
+
+
+
+
+
+
+## News
+- **11/03/2024** We provide a Gradio launching script for InstantIR; you can now deploy it on your local machine!
+- **11/02/2024** InstantIR is now compatible with 🧨 `diffusers`; you can utilize all the powerful features of this package!
+- **10/15/2024** Code and model released!
+
+## TODOs:
+- [ ] Launch online demo
+- [x] Remove dependency on local `diffusers`
+- [x] Gradio launching script
+
+## Usage
+
+
+### Quick start
+#### 1. Clone this repo and set up the environment
+```sh
+git clone https://github.com/JY-Joy/InstantIR.git
+cd InstantIR
+conda create -n instantir python=3.9 -y
+conda activate instantir
+pip install -r requirements.txt
+```
+
+#### 2. Download pre-trained models
+
+InstantIR is built on SDXL and DINOv2. You can download them either directly from 🤗 Hugging Face or via the `huggingface_hub` Python package.
+
+| 🤗 link | Python command
+| :--- | :----------
+|[SDXL](https://huggingface.co./stabilityai/stable-diffusion-xl-base-1.0) | `snapshot_download(repo_id="stabilityai/stable-diffusion-xl-base-1.0")`
+|[facebook/dinov2-large](https://huggingface.co./facebook/dinov2-large) | `snapshot_download(repo_id="facebook/dinov2-large")`
+|[InstantX/InstantIR](https://huggingface.co./InstantX/InstantIR) | `snapshot_download(repo_id="InstantX/InstantIR")`
+
+Note: if you use the Python commands, make sure to import the helper first with `from huggingface_hub import snapshot_download`.
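+
+For convenience, here is a minimal download sketch; the `local_dir` targets are illustrative and can be any writable paths:
+
+```py
+# Minimal sketch: fetch all three model repositories from the Hugging Face Hub.
+from huggingface_hub import snapshot_download
+
+snapshot_download(repo_id="stabilityai/stable-diffusion-xl-base-1.0", local_dir="./models/sdxl")
+snapshot_download(repo_id="facebook/dinov2-large", local_dir="./models/dinov2-large")
+snapshot_download(repo_id="InstantX/InstantIR", local_dir="./models/instantir")
+```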
+
+#### 3. Inference
+
+You can run InstantIR inference using `infer.sh` with the following arguments specified.
+
+```sh
+infer.sh \
+ --sdxl_path \
+ --vision_encoder_path \
+ --instantir_path \
+ --test_path \
+ --out_path
+```
+
+See `infer.py` for more config options.
+
+#### 4. Using tips
+
+InstantIR is powerful, but with your help it can do even better. Its flexible pipeline makes it tunable to a large extent. Here are some tips we found particularly useful for cases you may encounter:
+- **Over-smoothing**: reduce `--cfg` to 3.0~5.0. Overly high CFG scales can sometimes produce rigid lines or a lack of detail.
+- **Low fidelity**: set `--preview_start` to 0.1~0.4 to preserve fidelity to the input. The previewer can yield misleading references when the input latent is too noisy; in such cases, we suggest disabling the previewer at the early timesteps.
+- **Local distortions**: set `--creative_start` to 0.6~0.8. This lets InstantIR render freely in the late diffusion process, where the high-frequency details are generated. A smaller `--creative_start` leaves more room for creative restoration, but diminishes fidelity.
+- **Faster inference**: a higher `--preview_start` and a lower `--creative_start` both reduce computational cost and accelerate InstantIR inference.
+
+> [!CAUTION]
+> These features are training-free and therefore experimental. If you would like to try them, we suggest tuning these parameters case by case; the sketch below shows how they map onto the `diffusers` pipeline arguments.
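+
+As a concrete illustration, these flags correspond to the `diffusers` pipeline arguments forwarded by `infer.py`: `--cfg` maps to `guidance_scale`, `--preview_start` to `preview_start`, and `--creative_start` to `control_guidance_end`. A minimal, hedged sketch (assuming `pipe`, `lcm_scheduler`, and `low_quality_image` are set up as in the next section):
+
+```py
+# Hedged sketch: the tuning knobs above, forwarded the same way infer.py does.
+restored = pipe(
+    image=low_quality_image,
+    guidance_scale=4.0,           # --cfg: lower it if results look over-smoothed or rigid
+    preview_start=0.2,            # --preview_start: skip the previewer early on to preserve fidelity
+    control_guidance_end=0.7,     # --creative_start: allow free rendering in the late diffusion steps
+    previewer_scheduler=lcm_scheduler,
+).images[0]
+```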
+
+### Use InstantIR with diffusers 🧨
+
+InstantIR is fully compatible with `diffusers`, so you can take advantage of all the powerful features in this package. You can load InstantIR directly with the following `diffusers` snippet:
+
+```py
+# !pip install diffusers opencv-python transformers accelerate
+import torch
+from PIL import Image
+
+from diffusers import DDPMScheduler
+from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
+
+from module.ip_adapter.utils import load_adapter_to_pipe
+from pipelines.sdxl_instantir import InstantIRPipeline
+
+# suppose you have InstantIR weights under ./models
+instantir_path = f'./models'
+
+# load pretrained models
+pipe = InstantIRPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)
+
+# load adapter
+load_adapter_to_pipe(
+ pipe,
+ f"{instantir_path}/adapter.pt",
+ image_encoder_or_path = 'facebook/dinov2-large',
+)
+
+# load previewer lora
+pipe.prepare_previewers(instantir_path)
+pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
+lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
+
+# load aggregator weights
+pretrained_state_dict = torch.load(f"{instantir_path}/aggregator.pt", map_location="cpu")
+pipe.aggregator.load_state_dict(pretrained_state_dict)
+
+# send to GPU and fp16
+pipe.to(device='cuda', dtype=torch.float16)
+pipe.aggregator.to(device='cuda', dtype=torch.float16)
+```
+
+Then, you just need to call `pipe`, and InstantIR will restore your image!
+
+```py
+# load a broken image
+low_quality_image = Image.open('./assets/sculpture.png').convert("RGB")
+
+# InstantIR restoration
+image = pipe(
+ image=low_quality_image,
+ previewer_scheduler=lcm_scheduler,
+).images[0]
+```
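+
+The pipeline returns standard PIL images, so you can save or post-process the result as usual. You can also steer the restoration with text prompts, just like any other `diffusers` pipeline; a short, hedged example (the prompts are illustrative):
+
+```py
+# Optional: guide the restoration with text prompts, then save the result.
+image = pipe(
+    image=low_quality_image,
+    prompt="a marble sculpture, sharp details, photorealistic",
+    negative_prompt="blurry, low quality, jpeg artifacts",
+    previewer_scheduler=lcm_scheduler,
+).images[0]
+image.save("restored.png")
+```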
+
+### Deploy local gradio demo
+
+We provide a Python script that launches a local Gradio demo of InstantIR, with basic and some advanced features implemented. Start by running the following command in your terminal:
+
+```sh
+INSTANTIR_PATH=<path_to_instantir_weights> python gradio_demo/app.py
+```
+
+Then, visit your local demo via your browser at `http://localhost:7860`.
+
+
+## Training
+
+### Prepare data
+
+InstantIR is trained on [DIV2K](https://www.kaggle.com/datasets/joe1995/div2k-dataset), [Flickr2K](https://www.kaggle.com/datasets/daehoyang/flickr2k), [LSDIR](https://data.vision.ee.ethz.ch/yawli/index.html) and [FFHQ](https://www.kaggle.com/datasets/rahulbhalley/ffhq-1024x1024). We adopt dataset weighting to balance the distribution; you can configure the weights in `config_files/IR_dataset.yaml`. Download these training sets and put them under the same directory, which will be used in the following training configurations.
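+
+For reference, these weights map onto the dataclasses in `data/data_config.py`; a hedged, minimal sketch of the equivalent in-code configuration (mirroring `config_files/IR_dataset.yaml`):
+
+```py
+# Hedged sketch: in-code equivalent of config_files/IR_dataset.yaml.
+from data.data_config import DataConfig, SingleDataConfig
+
+data_config = DataConfig(
+    datasets=[
+        SingleDataConfig(dataset_folder="DIV2K", dataset_weight=0.3),
+        SingleDataConfig(dataset_folder="LSDIR", dataset_weight=0.3),
+        SingleDataConfig(dataset_folder="ffhq", dataset_weight=0.1),
+        SingleDataConfig(dataset_folder="Flickr2K", dataset_weight=0.1),
+    ]
+)
+```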
+
+### Two-stage training
+As described in our paper, the training of InstantIR is conducted in two stages. We provide corresponding `.sh` training scripts for each stage. Make sure you have the following arguments adapted to your own use case:
+
+| Argument | Value
+| :--- | :----------
+| `--pretrained_model_name_or_path` | path to your SDXL folder
+| `--feature_extractor_path` | path to your DINOv2 folder
+| `--train_data_dir` | your training data directory
+| `--output_dir` | path to save model weights
+| `--logging_dir` | path to save logs
+| `` | number of available GPUs
+
+Other training hyperparameters we used in our experiments are provided in the corresponding `.sh` scripts. You can tune them according to your own needs.
+
+## Acknowledgment
+Our work is sponsored by [HuggingFace](https://huggingface.co.) and [fal.ai](https://fal.ai).
+
+## Citation
+
+If InstantIR is helpful to your work, please cite our paper via:
+
+```
+@article{huang2024instantir,
+ title={InstantIR: Blind Image Restoration with Instant Generative Reference},
+ author={Huang, Jen-Yuan and Wang, Haofan and Wang, Qixun and Bai, Xu and Ai, Hao and Xing, Peng and Huang, Jen-Tse},
+ journal={arXiv preprint arXiv:2410.06551},
+ year={2024}
+}
+```
diff --git a/assets/.DS_Store b/assets/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6
Binary files /dev/null and b/assets/.DS_Store differ
diff --git a/assets/Konan.png b/assets/Konan.png
new file mode 100644
index 0000000000000000000000000000000000000000..24b44e20ff41032c4037790f88b7ca20f1e0adb6
Binary files /dev/null and b/assets/Konan.png differ
diff --git a/assets/Naruto.png b/assets/Naruto.png
new file mode 100644
index 0000000000000000000000000000000000000000..7eb3836d04f19864147b7d761c68acd66ca4b0a2
Binary files /dev/null and b/assets/Naruto.png differ
diff --git a/assets/cottage.png b/assets/cottage.png
new file mode 100644
index 0000000000000000000000000000000000000000..dcefa84eed0fac611b273b1efa58af5bd52f1e00
Binary files /dev/null and b/assets/cottage.png differ
diff --git a/assets/dog.png b/assets/dog.png
new file mode 100644
index 0000000000000000000000000000000000000000..58125cfdf3c8ee64c1e3930d7505f2c74160df80
Binary files /dev/null and b/assets/dog.png differ
diff --git a/assets/lady.png b/assets/lady.png
new file mode 100644
index 0000000000000000000000000000000000000000..a3df98b191dcaf2793d9364f1f525acbce165d7b
Binary files /dev/null and b/assets/lady.png differ
diff --git a/assets/man.png b/assets/man.png
new file mode 100644
index 0000000000000000000000000000000000000000..4a738b18f659ac7a23f03c3e3388f3f450382b2e
Binary files /dev/null and b/assets/man.png differ
diff --git a/assets/panda.png b/assets/panda.png
new file mode 100644
index 0000000000000000000000000000000000000000..f7c47480bc007be8a130dc4f24e13201373f868f
Binary files /dev/null and b/assets/panda.png differ
diff --git a/assets/sculpture.png b/assets/sculpture.png
new file mode 100644
index 0000000000000000000000000000000000000000..aa92a295a0884976629cd6796ddbe3d9d431827c
--- /dev/null
+++ b/assets/sculpture.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c4af7c3dc545d2f48b0ac2afef69bd7b1f0489ced7ea452d92f69ff5a9d4019
+size 1200996
diff --git a/assets/teaser_figure.png b/assets/teaser_figure.png
new file mode 100644
index 0000000000000000000000000000000000000000..cbe94f8394b912afd3e639ff49646e479583af8d
--- /dev/null
+++ b/assets/teaser_figure.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9c7e8e59af17516d11e21c5bc56b48824a3875c81e4afb181a5c3facc217d08
+size 16937178
diff --git a/config_files/IR_dataset.yaml b/config_files/IR_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..acfaa864ca8674082d74d268ffe266095608b6d6
--- /dev/null
+++ b/config_files/IR_dataset.yaml
@@ -0,0 +1,9 @@
+datasets:
+ - dataset_folder: 'ffhq'
+ dataset_weight: 0.1
+ - dataset_folder: 'DIV2K'
+ dataset_weight: 0.3
+ - dataset_folder: 'LSDIR'
+ dataset_weight: 0.3
+ - dataset_folder: 'Flickr2K'
+ dataset_weight: 0.1
diff --git a/config_files/losses.yaml b/config_files/losses.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a650909b1bae32cbbb7b82817a04e1750850d7e
--- /dev/null
+++ b/config_files/losses.yaml
@@ -0,0 +1,19 @@
+diffusion_losses:
+- name: L2Loss
+ weight: 1
+lcm_losses:
+- name: HuberLoss
+ weight: 1
+# - name: DINOLoss
+# weight: 1e-3
+# - name: L2Loss
+# weight: 5e-2
+# - name: LPIPSLoss
+# weight: 1e-3
+# - name: DreamSIMLoss
+# weight: 1e-3
+# - name: IDLoss
+# weight: 1e-3
+# visualize_every_k: 50
+# init_params:
+# pretrained_arcface_path: /home/dcor/orlichter/consistency_encoder_private/pretrained_models/model_ir_se50.pth
\ No newline at end of file
diff --git a/config_files/val_dataset.yaml b/config_files/val_dataset.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9674bb2f16ae220d6749f7294e1c05a5f0b453e6
--- /dev/null
+++ b/config_files/val_dataset.yaml
@@ -0,0 +1,7 @@
+datasets:
+ - dataset_folder: 'ffhq'
+ dataset_weight: 0.1
+ - dataset_folder: 'DIV2K'
+ dataset_weight: 0.45
+ - dataset_folder: 'LSDIR'
+ dataset_weight: 0.45
diff --git a/data/data_config.py b/data/data_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2536debf191ecbbe05eeebb1778141ed73456b47
--- /dev/null
+++ b/data/data_config.py
@@ -0,0 +1,14 @@
+from dataclasses import dataclass, field
+from typing import Optional, List
+
+
+@dataclass
+class SingleDataConfig:
+ dataset_folder: str
+ imagefolder: bool = True
+ dataset_weight: float = 1.0 # Not used yet
+
+@dataclass
+class DataConfig:
+ datasets: List[SingleDataConfig]
+ val_dataset: Optional[SingleDataConfig] = None
diff --git a/data/dataset.py b/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..17d84cdb3571b0f9b5de9cf6ee9a9b36d5723e6f
--- /dev/null
+++ b/data/dataset.py
@@ -0,0 +1,202 @@
+from pathlib import Path
+from typing import Optional
+
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torchvision import transforms
+import json
+import random
+from facenet_pytorch import MTCNN
+import torch
+
+from utils.utils import extract_faces_and_landmarks, REFERNCE_FACIAL_POINTS_RELATIVE
+
+def load_image(image_path: str) -> Image.Image:
+ image = Image.open(image_path)
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ return image
+
+
+class ImageDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ metadata_path: Optional[str] = None,
+ prompt_in_filename=False,
+ use_only_vanilla_for_encoder=False,
+ concept_placeholder='a face',
+ size=1024,
+ center_crop=False,
+ aug_images=False,
+ use_only_decoder_prompts=False,
+ crop_head_for_encoder_image=False,
+ random_target_prob=0.0,
+ ):
+ self.mtcnn = MTCNN(device='cuda:0')
+ self.mtcnn.forward = self.mtcnn.detect
+ resize_factor = 1.3
+ self.resized_reference_points = REFERNCE_FACIAL_POINTS_RELATIVE / resize_factor + (resize_factor - 1) / (2 * resize_factor)
+ self.size = size
+ self.center_crop = center_crop
+ self.concept_placeholder = concept_placeholder
+ self.prompt_in_filename = prompt_in_filename
+ self.aug_images = aug_images
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.name_to_label = None
+ self.crop_head_for_encoder_image = crop_head_for_encoder_image
+ self.random_target_prob = random_target_prob
+
+ self.use_only_decoder_prompts = use_only_decoder_prompts
+
+ self.instance_data_root = Path(instance_data_root)
+
+ if not self.instance_data_root.exists():
+ raise ValueError(f"Instance images root {self.instance_data_root} doesn't exist.")
+
+ if metadata_path is not None:
+ with open(metadata_path, 'r') as f:
+ self.name_to_label = json.load(f) # dict of filename: label
+ # Create a reversed mapping
+ self.label_to_names = {}
+ for name, label in self.name_to_label.items():
+ if use_only_vanilla_for_encoder and 'vanilla' not in name:
+ continue
+ if label not in self.label_to_names:
+ self.label_to_names[label] = []
+ self.label_to_names[label].append(name)
+ self.all_paths = [self.instance_data_root / filename for filename in self.name_to_label.keys()]
+
+ # Verify all paths exist
+ n_all_paths = len(self.all_paths)
+ self.all_paths = [path for path in self.all_paths if path.exists()]
+ print(f'Found {len(self.all_paths)} out of {n_all_paths} paths.')
+ else:
+ self.all_paths = [path for path in list(Path(instance_data_root).glob('**/*')) if
+ path.suffix.lower() in [".png", ".jpg", ".jpeg"]]
+ # Sort by name so that order for validation remains the same across runs
+ self.all_paths = sorted(self.all_paths, key=lambda x: x.stem)
+
+ self.custom_instance_prompts = None
+
+ self._length = len(self.all_paths)
+
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ if self.prompt_in_filename:
+ self.prompts_set = set([self._path_to_prompt(path) for path in self.all_paths])
+ else:
+ self.prompts_set = set([self.instance_prompt])
+
+ if self.aug_images:
+ self.aug_transforms = transforms.Compose(
+ [
+ transforms.RandomResizedCrop(size, scale=(0.8, 1.0), ratio=(1.0, 1.0)),
+ transforms.RandomHorizontalFlip(p=0.5)
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def _path_to_prompt(self, path):
+ # Remove the extension and seed
+ split_path = path.stem.split('_')
+ while split_path[-1].isnumeric():
+ split_path = split_path[:-1]
+
+ prompt = ' '.join(split_path)
+ # Replace placeholder in prompt with training placeholder
+ prompt = prompt.replace('conceptname', self.concept_placeholder)
+ return prompt
+
+ def __getitem__(self, index):
+ example = {}
+ instance_path = self.all_paths[index]
+ instance_image = load_image(instance_path)
+ example["instance_images"] = self.image_transforms(instance_image)
+ if self.prompt_in_filename:
+ example["instance_prompt"] = self._path_to_prompt(instance_path)
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.name_to_label is None:
+ # If no labels, simply take the same image but with different augmentation
+ example["encoder_images"] = self.aug_transforms(example["instance_images"]) if self.aug_images else example["instance_images"]
+ example["encoder_prompt"] = example["instance_prompt"]
+ else:
+ # Randomly select another image with the same label
+ instance_name = str(instance_path.relative_to(self.instance_data_root))
+ instance_label = self.name_to_label[instance_name]
+ label_set = set(self.label_to_names[instance_label])
+ if len(label_set) == 1:
+ # We are not supposed to have only one image per label, but just in case
+ encoder_image_name = instance_name
+ print(f'WARNING: Only one image for label {instance_label}.')
+ else:
+ encoder_image_name = random.choice(list(label_set - {instance_name}))
+ encoder_image = load_image(self.instance_data_root / encoder_image_name)
+ example["encoder_images"] = self.image_transforms(encoder_image)
+
+ if self.prompt_in_filename:
+ example["encoder_prompt"] = self._path_to_prompt(self.instance_data_root / encoder_image_name)
+ else:
+ example["encoder_prompt"] = self.instance_prompt
+
+ if self.crop_head_for_encoder_image:
+ example["encoder_images"] = extract_faces_and_landmarks(example["encoder_images"][None], self.size, self.mtcnn, self.resized_reference_points)[0][0]
+ example["encoder_prompt"] = example["encoder_prompt"].format(placeholder="")
+ example["instance_prompt"] = example["instance_prompt"].format(placeholder="")
+
+ if random.random() < self.random_target_prob:
+ random_path = random.choice(self.all_paths)
+
+ random_image = load_image(random_path)
+ example["instance_images"] = self.image_transforms(random_image)
+ if self.prompt_in_filename:
+ example["instance_prompt"] = self._path_to_prompt(random_path)
+
+
+ if self.use_only_decoder_prompts:
+ example["encoder_prompt"] = example["instance_prompt"]
+
+ return example
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ encoder_pixel_values = [example["encoder_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+ encoder_prompts = [example["encoder_prompt"] for example in examples]
+
+ if with_prior_preservation:
+ raise NotImplementedError("Prior preservation not implemented.")
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ encoder_pixel_values = torch.stack(encoder_pixel_values)
+ encoder_pixel_values = encoder_pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "encoder_pixel_values": encoder_pixel_values,
+ "prompts": prompts, "encoder_prompts": encoder_prompts}
+ return batch
diff --git a/docs/.DS_Store b/docs/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..ecb218a51788f964c73e12aac69933b3b8193ec9
Binary files /dev/null and b/docs/.DS_Store differ
diff --git a/docs/static/.DS_Store b/docs/static/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..f99ae43bb14d60ad97b7e197cf9798c9be86ac69
Binary files /dev/null and b/docs/static/.DS_Store differ
diff --git a/environment.yaml b/environment.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1ce87f39711264bd96d97c60530d8ed707bb80ea
--- /dev/null
+++ b/environment.yaml
@@ -0,0 +1,37 @@
+name: instantir
+channels:
+ - pytorch
+ - nvidia
+ - conda-forge
+ - defaults
+dependencies:
+ - numpy
+ - pandas
+ - pillow
+ - pip
+ - python=3.9.15
+ - pytorch=2.2.2
+ - pytorch-lightning=1.6.5
+ - pytorch-cuda=12.1
+ - setuptools
+ - torchaudio=2.2.2
+ - torchmetrics
+ - torchvision=0.17.2
+ - tqdm
+ - pip:
+ - accelerate==0.25.0
+ - diffusers==0.24.0
+ - einops
+ - open-clip-torch
+ - opencv-python==4.8.1.78
+ - tokenizers
+ - transformers==4.36.2
+ - kornia
+ - facenet_pytorch
+ - lpips
+ - dreamsim
+ - pyrallis
+ - wandb
+ - insightface
+ - onnxruntime==1.17.0
+ - -e git+https://github.com/openai/CLIP.git@main#egg=clip
\ No newline at end of file
diff --git a/gradio_demo/app.py b/gradio_demo/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5301676376eb1e1c2fbb8545651ca5fc38d3947
--- /dev/null
+++ b/gradio_demo/app.py
@@ -0,0 +1,250 @@
+import os
+import sys
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
+
+import torch
+import numpy as np
+import gradio as gr
+from PIL import Image
+
+from diffusers import DDPMScheduler
+from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
+
+from module.ip_adapter.utils import load_adapter_to_pipe
+from pipelines.sdxl_instantir import InstantIRPipeline
+
+def resize_img(input_image, max_side=1280, min_side=1024, size=None,
+ pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
+
+ w, h = input_image.size
+ if size is not None:
+ w_resize_new, h_resize_new = size
+ else:
+ # ratio = min_side / min(h, w)
+ # w, h = round(ratio*w), round(ratio*h)
+ ratio = max_side / max(h, w)
+ input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
+ w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+ h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
+ input_image = input_image.resize([w_resize_new, h_resize_new], mode)
+
+ if pad_to_max_side:
+ res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
+ offset_x = (max_side - w_resize_new) // 2
+ offset_y = (max_side - h_resize_new) // 2
+ res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
+ input_image = Image.fromarray(res)
+ return input_image
+
+instantir_path = os.environ['INSTANTIR_PATH']
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+sdxl_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
+dinov2_repo_id = "facebook/dinov2-large"
+lcm_repo_id = "latent-consistency/lcm-lora-sdxl"
+
+if torch.cuda.is_available():
+ torch_dtype = torch.float16
+else:
+ torch_dtype = torch.float32
+
+# Load pretrained models.
+print("Initializing pipeline...")
+pipe = InstantIRPipeline.from_pretrained(
+ sdxl_repo_id,
+ torch_dtype=torch_dtype,
+)
+
+# Image prompt projector.
+print("Loading LQ-Adapter...")
+load_adapter_to_pipe(
+ pipe,
+ f"{instantir_path}/adapter.pt",
+ dinov2_repo_id,
+)
+
+# Prepare previewer
+lora_alpha = pipe.prepare_previewers(instantir_path)
+print(f"use lora alpha {lora_alpha}")
+lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True)
+print(f"use lora alpha {lora_alpha}")
+pipe.to(device=device, dtype=torch_dtype)
+pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler")
+lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
+
+# Load weights.
+print("Loading checkpoint...")
+aggregator_state_dict = torch.load(
+ f"{instantir_path}/aggregator.pt",
+ map_location="cpu"
+)
+pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True)
+pipe.aggregator.to(device=device, dtype=torch_dtype)
+
+MAX_SEED = np.iinfo(np.int32).max
+MAX_IMAGE_SIZE = 1024
+
+PROMPT = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \
+ultra HD, extreme meticulous detailing, skin pore detailing, \
+hyper sharpness, perfect without deformations, \
+taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. "
+
+NEG_PROMPT = "blurry, out of focus, unclear, depth of field, over-smooth, \
+sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \
+dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \
+watermark, signature, jpeg artifacts, deformed, lowres"
+
+def unpack_pipe_out(preview_row, index):
+ return preview_row[index][0]
+
+def dynamic_preview_slider(sampling_steps):
+ print(sampling_steps)
+ return gr.Slider(label="Restoration Previews", value=sampling_steps-1, minimum=0, maximum=sampling_steps-1, step=1)
+
+def dynamic_guidance_slider(sampling_steps):
+ return gr.Slider(label="Start Free Rendering", value=sampling_steps, minimum=0, maximum=sampling_steps, step=1)
+
+def show_final_preview(preview_row):
+ return preview_row[-1][0]
+
+# @spaces.GPU #[uncomment to use ZeroGPU]
+@torch.no_grad()
+def instantir_restore(
+ lq, prompt="", steps=30, cfg_scale=7.0, guidance_end=1.0,
+ creative_restoration=False, seed=3407, height=1024, width=1024, preview_start=0.0):
+ if creative_restoration:
+ if "lcm" not in pipe.unet.active_adapters():
+ pipe.unet.set_adapter('lcm')
+ else:
+ if "previewer" not in pipe.unet.active_adapters():
+ pipe.unet.set_adapter('previewer')
+
+ if isinstance(guidance_end, int):
+ guidance_end = guidance_end / steps
+ elif guidance_end > 1.0:
+ guidance_end = guidance_end / steps
+ if isinstance(preview_start, int):
+ preview_start = preview_start / steps
+ elif preview_start > 1.0:
+ preview_start = preview_start / steps
+ lq = [resize_img(lq.convert("RGB"), size=(width, height))]
+ generator = torch.Generator(device=device).manual_seed(seed)
+ timesteps = [
+ i * (1000//steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
+ ]
+ timesteps = timesteps[::-1]
+
+ prompt = PROMPT if len(prompt)==0 else prompt
+ neg_prompt = NEG_PROMPT
+
+ out = pipe(
+ prompt=[prompt]*len(lq),
+ image=lq,
+ num_inference_steps=steps,
+ generator=generator,
+ timesteps=timesteps,
+ negative_prompt=[neg_prompt]*len(lq),
+ guidance_scale=cfg_scale,
+ control_guidance_end=guidance_end,
+ preview_start=preview_start,
+ previewer_scheduler=lcm_scheduler,
+ return_dict=False,
+ save_preview_row=True,
+ )
+ for i, preview_img in enumerate(out[1]):
+ preview_img.append(f"preview_{i}")
+ return out[0][0], out[1]
+
+examples = [
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+ "An astronaut riding a green horse",
+ "A delicious ceviche cheesecake slice",
+]
+
+css="""
+#col-container {
+ margin: 0 auto;
+ max-width: 640px;
+}
+"""
+
+with gr.Blocks() as demo:
+ gr.Markdown(
+ """
+ # InstantIR: Blind Image Restoration with Instant Generative Reference.
+
+    ### **Official 🤗 Gradio demo of [InstantIR](https://arxiv.org/abs/2410.06551).**
+    ### **InstantIR not only helps you restore your broken image, but is also capable of imaginative re-creation following your text prompts. See advanced usage for more details!**
+ ## Basic usage: revitalize your image
+ 1. Upload an image you want to restore;
+    2. Optionally, tune the `Steps` and `CFG Scale` parameters. More steps typically lead to better results, but fewer than 50 is recommended for efficiency;
+ 3. Click `InstantIR magic!`.
+ """)
+ with gr.Row():
+ lq_img = gr.Image(label="Low-quality image", type="pil")
+ with gr.Column():
+ with gr.Row():
+ steps = gr.Number(label="Steps", value=30, step=1)
+ cfg_scale = gr.Number(label="CFG Scale", value=7.0, step=0.1)
+ with gr.Row():
+ height = gr.Number(label="Height", value=1024, step=1)
+                width = gr.Number(label="Width", value=1024, step=1)
+ seed = gr.Number(label="Seed", value=42, step=1)
+ # guidance_start = gr.Slider(label="Guidance Start", value=1.0, minimum=0.0, maximum=1.0, step=0.05)
+ guidance_end = gr.Slider(label="Start Free Rendering", value=30, minimum=0, maximum=30, step=1)
+ preview_start = gr.Slider(label="Preview Start", value=0, minimum=0, maximum=30, step=1)
+ prompt = gr.Textbox(label="Restoration prompts (Optional)", placeholder="")
+ mode = gr.Checkbox(label="Creative Restoration", value=False)
+ with gr.Row():
+ with gr.Row():
+ restore_btn = gr.Button("InstantIR magic!")
+ clear_btn = gr.ClearButton()
+ index = gr.Slider(label="Restoration Previews", value=29, minimum=0, maximum=29, step=1)
+ with gr.Row():
+ output = gr.Image(label="InstantIR restored", type="pil")
+ preview = gr.Image(label="Preview", type="pil")
+ pipe_out = gr.Gallery(visible=False)
+ clear_btn.add([lq_img, output, preview])
+ restore_btn.click(
+ instantir_restore, inputs=[
+ lq_img, prompt, steps, cfg_scale, guidance_end,
+            mode, seed, height, width, preview_start,
+ ],
+ outputs=[output, pipe_out], api_name="InstantIR"
+ )
+ steps.change(dynamic_guidance_slider, inputs=steps, outputs=guidance_end)
+ output.change(dynamic_preview_slider, inputs=steps, outputs=index)
+ index.release(unpack_pipe_out, inputs=[pipe_out, index], outputs=preview)
+ output.change(show_final_preview, inputs=pipe_out, outputs=preview)
+ gr.Markdown(
+ """
+    ## Advanced usage:
+ ### Browse restoration variants:
+ 1. After InstantIR processing, drag the `Restoration Previews` slider to explore other in-progress versions;
+ 2. If you like one of them, set the `Start Free Rendering` slider to the same value to get a more refined result.
+ ### Creative restoration:
+ 1. Check the `Creative Restoration` checkbox;
+ 2. Input your text prompts in the `Restoration prompts` textbox;
+ 3. Set `Start Free Rendering` slider to a medium value (around half of the `steps`) to provide adequate room for InstantIR creation.
+
+ ## Examples
+    Here are some example usages of InstantIR:
+ """)
+ # examples = gr.Gallery(label="Examples")
+
+ gr.Markdown(
+ """
+ ## Citation
+ If InstantIR is helpful to your work, please cite our paper via:
+
+ ```
+ @article{huang2024instantir,
+ title={InstantIR: Blind Image Restoration with Instant Generative Reference},
+ author={Huang, Jen-Yuan and Wang, Haofan and Wang, Qixun and Bai, Xu and Ai, Hao and Xing, Peng and Huang, Jen-Tse},
+ journal={arXiv preprint arXiv:2410.06551},
+ year={2024}
+ }
+ ```
+ """)
+
+demo.queue().launch()
\ No newline at end of file
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa021623a6a6a0e90acc271483f7413dfe24bd60
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,381 @@
+import os
+import argparse
+import numpy as np
+import torch
+
+from PIL import Image
+from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
+
+from diffusers import DDPMScheduler
+
+from module.ip_adapter.utils import load_adapter_to_pipe
+from pipelines.sdxl_instantir import InstantIRPipeline
+
+
+def name_unet_submodules(unet):
+ def recursive_find_module(name, module, end=False):
+ if end:
+ for sub_name, sub_module in module.named_children():
+ sub_module.full_name = f"{name}.{sub_name}"
+ return
+ if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return
+ elif "resnets" in name: return
+ for sub_name, sub_module in module.named_children():
+ end = True if sub_name == "transformer_blocks" else False
+ recursive_find_module(f"{name}.{sub_name}", sub_module, end)
+
+ for name, module in unet.named_children():
+ recursive_find_module(name, module)
+
+
+def resize_img(input_image, max_side=1280, min_side=1024, size=None,
+ pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
+
+ w, h = input_image.size
+ if size is not None:
+ w_resize_new, h_resize_new = size
+ else:
+ # ratio = min_side / min(h, w)
+ # w, h = round(ratio*w), round(ratio*h)
+ ratio = max_side / max(h, w)
+ input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
+ w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
+ h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
+ input_image = input_image.resize([w_resize_new, h_resize_new], mode)
+
+ if pad_to_max_side:
+ res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
+ offset_x = (max_side - w_resize_new) // 2
+ offset_y = (max_side - h_resize_new) // 2
+ res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
+ input_image = Image.fromarray(res)
+ return input_image
+
+
+def tensor_to_pil(images):
+ """
+ Convert image tensor or a batch of image tensors to PIL image(s).
+ """
+ images = images.clamp(0, 1)
+ images_np = images.detach().cpu().numpy()
+ if images_np.ndim == 4:
+ images_np = np.transpose(images_np, (0, 2, 3, 1))
+ elif images_np.ndim == 3:
+ images_np = np.transpose(images_np, (1, 2, 0))
+ images_np = images_np[None, ...]
+ images_np = (images_np * 255).round().astype("uint8")
+ if images_np.shape[-1] == 1:
+ # special case for grayscale (single channel) images
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images_np]
+ else:
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images_np]
+
+ return pil_images
+
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+def main(args, device):
+
+ # Load pretrained models.
+ pipe = InstantIRPipeline.from_pretrained(
+ args.sdxl_path,
+ torch_dtype=torch.float16,
+ )
+
+ # Image prompt projector.
+ print("Loading LQ-Adapter...")
+ load_adapter_to_pipe(
+ pipe,
+ args.adapter_model_path if args.adapter_model_path is not None else os.path.join(args.instantir_path, 'adapter.pt'),
+ args.vision_encoder_path,
+ use_clip_encoder=args.use_clip_encoder,
+ )
+
+ # Prepare previewer
+ previewer_lora_path = args.previewer_lora_path if args.previewer_lora_path is not None else args.instantir_path
+ if previewer_lora_path is not None:
+ lora_alpha = pipe.prepare_previewers(previewer_lora_path)
+ print(f"use lora alpha {lora_alpha}")
+ pipe.to(device=device, dtype=torch.float16)
+ pipe.scheduler = DDPMScheduler.from_pretrained(args.sdxl_path, subfolder="scheduler")
+ lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
+
+ # Load weights.
+ print("Loading checkpoint...")
+ pretrained_state_dict = torch.load(os.path.join(args.instantir_path, "aggregator.pt"), map_location="cpu")
+ pipe.aggregator.load_state_dict(pretrained_state_dict)
+ pipe.aggregator.to(device, dtype=torch.float16)
+
+ #################### Restoration ####################
+
+ post_fix = f"_{args.post_fix}" if args.post_fix else ""
+ os.makedirs(f"{args.out_path}/{post_fix}", exist_ok=True)
+
+ processed_imgs = os.listdir(os.path.join(args.out_path, post_fix))
+ lq_files = []
+ lq_batch = []
+ if os.path.isfile(args.test_path):
+ all_inputs = [args.test_path.split("/")[-1]]
+ else:
+ all_inputs = os.listdir(args.test_path)
+ all_inputs.sort()
+ for file in all_inputs:
+ if file in processed_imgs:
+ print(f"Skip {file}")
+ continue
+ lq_batch.append(f"{file}")
+ if len(lq_batch) == args.batch_size:
+ lq_files.append(lq_batch)
+ lq_batch = []
+
+ if len(lq_batch) > 0:
+ lq_files.append(lq_batch)
+
+ for lq_batch in lq_files:
+ generator = torch.Generator(device=device).manual_seed(args.seed)
+ pil_lqs = [Image.open(os.path.join(args.test_path, file)) for file in lq_batch]
+ if args.width is None or args.height is None:
+ lq = [resize_img(pil_lq.convert("RGB"), size=None) for pil_lq in pil_lqs]
+ else:
+ lq = [resize_img(pil_lq.convert("RGB"), size=(args.width, args.height)) for pil_lq in pil_lqs]
+ timesteps = None
+ if args.denoising_start < 1000:
+ timesteps = [
+ i * (args.denoising_start//args.num_inference_steps) + pipe.scheduler.config.steps_offset for i in range(0, args.num_inference_steps)
+ ]
+ timesteps = timesteps[::-1]
+ pipe.scheduler.set_timesteps(args.num_inference_steps, device)
+ timesteps = pipe.scheduler.timesteps
+ if args.prompt is None or len(args.prompt) == 0:
+ prompt = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \
+ ultra HD, extreme meticulous detailing, skin pore detailing, \
+ hyper sharpness, perfect without deformations, \
+ taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. "
+ else:
+ prompt = args.prompt
+ if not isinstance(prompt, list):
+ prompt = [prompt]
+ prompt = prompt*len(lq)
+ if args.neg_prompt is None or len(args.neg_prompt) == 0:
+ neg_prompt = "blurry, out of focus, unclear, depth of field, over-smooth, \
+ sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \
+ dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \
+ watermark, signature, jpeg artifacts, deformed, lowres"
+ else:
+ neg_prompt = args.neg_prompt
+ if not isinstance(neg_prompt, list):
+ neg_prompt = [neg_prompt]
+ neg_prompt = neg_prompt*len(lq)
+ image = pipe(
+ prompt=prompt,
+ image=lq,
+ num_inference_steps=args.num_inference_steps,
+ generator=generator,
+ timesteps=timesteps,
+ negative_prompt=neg_prompt,
+ guidance_scale=args.cfg,
+ previewer_scheduler=lcm_scheduler,
+ preview_start=args.preview_start,
+ control_guidance_end=args.creative_start,
+ ).images
+
+ if args.save_preview_row:
+ for i, lcm_image in enumerate(image[1]):
+ lcm_image.save(f"./lcm/{i}.png")
+ for i, rec_image in enumerate(image):
+ rec_image.save(f"{args.out_path}/{post_fix}/{lq_batch[i]}")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="InstantIR pipeline")
+ parser.add_argument(
+ "--sdxl_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--previewer_lora_path",
+ type=str,
+ default=None,
+ help="Path to LCM lora or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--instantir_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained instantir model.",
+ )
+ parser.add_argument(
+ "--vision_encoder_path",
+ type=str,
+ default='/share/huangrenyuan/model_zoo/vis_backbone/dinov2_large',
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--adapter_model_path",
+ type=str,
+ default=None,
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--adapter_tokens",
+ type=int,
+ default=64,
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
+ )
+ parser.add_argument(
+ "--use_clip_encoder",
+ action="store_true",
+ help="Whether or not to use DINO as image encoder, else CLIP encoder.",
+ )
+ parser.add_argument(
+ "--denoising_start",
+ type=int,
+ default=1000,
+ help="Diffusion start timestep."
+ )
+ parser.add_argument(
+ "--num_inference_steps",
+ type=int,
+ default=30,
+ help="Diffusion steps."
+ )
+ parser.add_argument(
+ "--creative_start",
+ type=float,
+ default=1.0,
+ help="Proportion of timesteps for creative restoration. 1.0 means no creative restoration while 0.0 means completely free rendering."
+ )
+ parser.add_argument(
+ "--preview_start",
+ type=float,
+ default=0.0,
+ help="Proportion of timesteps to stop previewing at the begining to enhance fidelity to input."
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=6,
+ help="Test batch size."
+ )
+ parser.add_argument(
+ "--width",
+ type=int,
+ default=None,
+ help="Output image width."
+ )
+ parser.add_argument(
+ "--height",
+ type=int,
+ default=None,
+ help="Output image height."
+ )
+ parser.add_argument(
+ "--cfg",
+ type=float,
+ default=7.0,
+ help="Scale of Classifier-Free-Guidance (CFG).",
+ )
+ parser.add_argument(
+ "--post_fix",
+ type=str,
+ default=None,
+ help="Subfolder name for restoration output under the output directory.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default='fp16',
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--save_preview_row",
+ action="store_true",
+ help="Whether or not to save the intermediate lcm outputs.",
+ )
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ default='',
+ nargs="+",
+ help=(
+ "A set of prompts for creative restoration. Provide either a matching number of test images,"
+ " or a single prompt to be used with all inputs."
+ ),
+ )
+ parser.add_argument(
+ "--neg_prompt",
+ type=str,
+ default='',
+ nargs="+",
+ help=(
+ "A set of negative prompts for creative restoration. Provide either a matching number of test images,"
+ " or a single negative prompt to be used with all inputs."
+ ),
+ )
+ parser.add_argument(
+ "--test_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Test directory.",
+ )
+ parser.add_argument(
+ "--out_path",
+ type=str,
+ default="./output",
+ help="Output directory.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ args = parser.parse_args()
+ args.height = args.height or args.width
+ args.width = args.width or args.height
+ if args.height is not None and (args.width % 64 != 0 or args.height % 64 != 0):
+ raise ValueError("Image resolution must be divisible by 64.")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ main(args, device)
\ No newline at end of file
diff --git a/infer.sh b/infer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9d7e54337d06dd33b857ce4149d164e8ac8cdb3a
--- /dev/null
+++ b/infer.sh
@@ -0,0 +1,6 @@
+python infer.py \
+ --sdxl_path path/to/sdxl \
+ --vision_encoder_path path/to/dinov2_large \
+ --instantir_path path/to/instantir \
+ --test_path path/to/input \
+ --out_path path/to/output
\ No newline at end of file
diff --git a/losses/loss_config.py b/losses/loss_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..64a384818982b598f484eeaf908ec91c1782782f
--- /dev/null
+++ b/losses/loss_config.py
@@ -0,0 +1,15 @@
+from dataclasses import dataclass, field
+from typing import List
+
+@dataclass
+class SingleLossConfig:
+ name: str
+ weight: float = 1.
+ init_params: dict = field(default_factory=dict)
+ visualize_every_k: int = -1
+
+
+@dataclass
+class LossesConfig:
+ diffusion_losses: List[SingleLossConfig]
+ lcm_losses: List[SingleLossConfig]
\ No newline at end of file
diff --git a/losses/losses.py b/losses/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..71de2791906d1f68871a16974bb2cd9d7fabbf37
--- /dev/null
+++ b/losses/losses.py
@@ -0,0 +1,465 @@
+import torch
+import wandb
+import cv2
+import torch.nn.functional as F
+import numpy as np
+from facenet_pytorch import MTCNN
+from torchvision import transforms
+from dreamsim import dreamsim
+from einops import rearrange
+import kornia.augmentation as K
+import lpips
+
+from pretrained_models.arcface import Backbone
+from utils.vis_utils import add_text_to_image
+from utils.utils import extract_faces_and_landmarks
+import clip
+
+
+class Loss():
+ """
+ General purpose loss class.
+ Mainly handles dtype and visualize_every_k.
+ keeps current iteration of loss, mainly for visualization purposes.
+ """
+ def __init__(self, visualize_every_k=-1, dtype=torch.float32, accelerator=None, **kwargs):
+ self.visualize_every_k = visualize_every_k
+ self.iteration = -1
+ self.dtype=dtype
+ self.accelerator = accelerator
+
+ def __call__(self, **kwargs):
+ self.iteration += 1
+ return self.forward(**kwargs)
+
+
+class L1Loss(Loss):
+ """
+    Simple L1 loss between the prediction and the target.
+
+    Args:
+        predict (torch.Tensor): predicted tensor, e.g. pixel values decoded from the 1-step LCM preview.
+        target (torch.Tensor): target tensor, e.g. the encoder input image.
+ """
+ def forward(
+ self,
+ predict: torch.Tensor,
+ target: torch.Tensor,
+ **kwargs
+ ) -> torch.Tensor:
+ return F.l1_loss(predict, target, reduction="mean")
+
+
+class DreamSIMLoss(Loss):
+ """DreamSIM loss between predicted_pixel_values and pixel_values.
+ DreamSIM is similar to LPIPS (https://dreamsim-nights.github.io/) but is trained on more human defined similarity dataset
+ DreamSIM expects an RGB image of size 224x224 and values between 0 and 1. So we need to normalize the input images to 0-1 range and resize them to 224x224.
+ Args:
+ predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
+ encoder_pixel_values (torch.Tesnor): The input image to the encoder
+ """
+ def __init__(self, device: str='cuda:0', **kwargs):
+ super().__init__(**kwargs)
+ self.model, _ = dreamsim(pretrained=True, device=device)
+ self.model.to(dtype=self.dtype, device=device)
+ self.model = self.accelerator.prepare(self.model)
+ self.transforms = transforms.Compose([
+ transforms.Lambda(lambda x: (x + 1) / 2),
+ transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC)])
+
+ def forward(
+ self,
+ predicted_pixel_values: torch.Tensor,
+ encoder_pixel_values: torch.Tensor,
+ **kwargs,
+ ) -> torch.Tensor:
+ predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)
+ encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
+ return self.model(self.transforms(predicted_pixel_values), self.transforms(encoder_pixel_values)).mean()
+
+
+class LPIPSLoss(Loss):
+ """LPIPS loss between predicted_pixel_values and pixel_values.
+ Args:
+ predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
+ encoder_pixel_values (torch.Tesnor): The input image to the encoder
+ """
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.model = lpips.LPIPS(net='vgg')
+ self.model.to(dtype=self.dtype, device=self.accelerator.device)
+ self.model = self.accelerator.prepare(self.model)
+
+ def forward(self, predict, target, **kwargs):
+ predict = predict.to(dtype=self.dtype)
+ target = target.to(dtype=self.dtype)
+ return self.model(predict, target).mean()
+
+
+class LCMVisualization(Loss):
+ """Dummy loss used to visualize the LCM outputs
+ Args:
+ predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
+ pixel_values (torch.Tensor): The input image to the decoder
+ encoder_pixel_values (torch.Tensor): The input image to the encoder.
+ """
+ def forward(
+ self,
+ predicted_pixel_values: torch.Tensor,
+ pixel_values: torch.Tensor,
+ encoder_pixel_values: torch.Tensor,
+ timesteps: torch.Tensor,
+ **kwargs,
+ ) -> None:
+ if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
+ predicted_pixel_values = rearrange(predicted_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
+ pixel_values = rearrange(pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
+ encoder_pixel_values = rearrange(encoder_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
+ image = np.hstack([encoder_pixel_values, pixel_values, predicted_pixel_values])
+ for tracker in self.accelerator.trackers:
+ if tracker.name == 'wandb':
+ tracker.log({"TrainVisualization": wandb.Image(image, caption=f"Encoder Input Image, Decoder Input Image, Predicted LCM Image. Timesteps {timesteps.cpu().tolist()}")})
+ return torch.tensor(0.0)
+
+
+class L2Loss(Loss):
+ """
+ Regular diffusion (MSE) loss between the predicted noise and the target noise.
+
+ Args:
+ predict (torch.Tensor): Noise predicted by the diffusion model.
+ target (torch.Tensor): Actual noise added to the image.
+ """
+ def forward(
+ self,
+ predict: torch.Tensor,
+ target: torch.Tensor,
+ weights: torch.Tensor = None,
+ **kwargs
+ ) -> torch.Tensor:
+ if weights is not None:
+ loss = (predict.float() - target.float()).pow(2) * weights
+ return loss.mean()
+ return F.mse_loss(predict.float(), target.float(), reduction="mean")
+
+
+class HuberLoss(Loss):
+ """Huber loss between predicted_pixel_values and pixel_values.
+ Args:
+ predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
+ encoder_pixel_values (torch.Tesnor): The input image to the encoder
+ """
+ def __init__(self, huber_c=0.001, **kwargs):
+ super().__init__(**kwargs)
+ self.huber_c = huber_c
+
+ def forward(
+ self,
+ predict: torch.Tensor,
+ target: torch.Tensor,
+ weights: torch.Tensor = None,
+ **kwargs
+ ) -> torch.Tensor:
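+ # Pseudo-Huber loss: behaves like a scaled L2 loss for errors much smaller than huber_c and like L1 for large errors.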
+ loss = torch.sqrt((predict.float() - target.float()) ** 2 + self.huber_c**2) - self.huber_c
+ if weights is not None:
+ return (loss * weights).mean()
+ return loss.mean()
+
+
+class WeightedNoiseLoss(Loss):
+ """
+ Weighted diffusion loss between the predicted noise and the target noise.
+
+ Args:
+ predict (torch.Tensor): Noise predicted by the diffusion model.
+ target (torch.Tensor): Actual noise added to the image.
+ weights (torch.Tensor): Per-batch-item weights. Can be used to e.g. zero out the loss for InstantID training if keypoint extraction fails.
+ """
+ def forward(
+ self,
+ predict: torch.Tensor,
+ target: torch.Tensor,
+ weights,
+ **kwargs
+ ) -> torch.Tensor:
+ return F.mse_loss(predict.float() * weights, target.float() * weights, reduction="mean")
+
+
+class IDLoss(Loss):
+ """
+ Use a pretrained ArcFace face-recognition backbone (`self.facenet`) to extract identity features from the faces of the predicted and target images.
+ The backbone expects 112x112 inputs, so the face is cropped with MTCNN and resized to 112x112.
+ The loss is 1 minus the cosine similarity between the two feature vectors, i.e. the cosine distance.
+ Note that the backbone's outputs are L2-normalized, so the dot product equals the cosine similarity.
+ """
+ def __init__(self, pretrained_arcface_path: str, skip_not_found=True, **kwargs):
+ super().__init__(**kwargs)
+ assert pretrained_arcface_path is not None, "please pass `pretrained_arcface_path` in the losses config. You can download the pretrained model from "\
+ "https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing"
+ self.mtcnn = MTCNN(device=self.accelerator.device)
+ self.mtcnn.forward = self.mtcnn.detect
+ self.facenet_input_size = 112 # Has to be 112, can't find weights for 224 size.
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
+ self.facenet.load_state_dict(torch.load(pretrained_arcface_path))
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((self.facenet_input_size, self.facenet_input_size))
+ self.facenet.requires_grad_(False)
+ self.facenet.eval()
+ self.facenet.to(device=self.accelerator.device, dtype=self.dtype) # not implemented for half precision
+ self.face_pool.to(device=self.accelerator.device, dtype=self.dtype) # not implemented for half precision
+ self.visualization_resize = transforms.Resize((self.facenet_input_size, self.facenet_input_size), interpolation=transforms.InterpolationMode.BICUBIC)
+ self.reference_facial_points = np.array([[38.29459953, 51.69630051],
+ [72.53179932, 51.50139999],
+ [56.02519989, 71.73660278],
+ [41.54930115, 92.3655014],
+ [70.72990036, 92.20410156]
+ ]) # Original reference points are for 112x96 crops; 8 is added to the x axis to make them 112x112.
+ self.facenet, self.face_pool, self.mtcnn = self.accelerator.prepare(self.facenet, self.face_pool, self.mtcnn)
+
+ self.skip_not_found = skip_not_found
+
+ def extract_feats(self, x: torch.Tensor):
+ """
+ Extract features from the face of the image using facenet model.
+ """
+ x = self.face_pool(x)
+ x_feats = self.facenet(x)
+
+ return x_feats
+
+ def forward(
+ self,
+ predicted_pixel_values: torch.Tensor,
+ encoder_pixel_values: torch.Tensor,
+ timesteps: torch.Tensor,
+ **kwargs
+ ):
+ encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
+ predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)
+
+ predicted_pixel_values_face, predicted_invalid_indices = extract_faces_and_landmarks(predicted_pixel_values, mtcnn=self.mtcnn)
+ with torch.no_grad():
+ encoder_pixel_values_face, source_invalid_indices = extract_faces_and_landmarks(encoder_pixel_values, mtcnn=self.mtcnn)
+
+ if self.skip_not_found:
+ valid_indices = []
+ for i in range(predicted_pixel_values.shape[0]):
+ if i not in predicted_invalid_indices and i not in source_invalid_indices:
+ valid_indices.append(i)
+ else:
+ valid_indices = list(range(predicted_pixel_values.shape[0]))
+
+ valid_indices = torch.tensor(valid_indices).to(device=predicted_pixel_values.device)
+
+ if len(valid_indices) == 0:
+ loss = (predicted_pixel_values_face * 0.0).mean() # Multiply by zero so that backward() still releases the computation graph of predicted_pixel_values.
+ if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
+ self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
+ return loss
+
+ with torch.no_grad():
+ pixel_values_feats = self.extract_feats(encoder_pixel_values_face[valid_indices])
+
+ predicted_pixel_values_feats = self.extract_feats(predicted_pixel_values_face[valid_indices])
+ loss = 1 - torch.einsum("bi,bi->b", pixel_values_feats, predicted_pixel_values_feats)
+
+ if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
+ self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
+ return loss.mean()
+
+ def visualize(
+ self,
+ predicted_pixel_values: torch.Tensor,
+ encoder_pixel_values: torch.Tensor,
+ predicted_pixel_values_face: torch.Tensor,
+ encoder_pixel_values_face: torch.Tensor,
+ timesteps: torch.Tensor,
+ valid_indices: torch.Tensor,
+ loss: torch.Tensor,
+ ) -> None:
+ small_predicted_pixel_values = (rearrange(self.visualization_resize(predicted_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy())
+ small_pixel_values = rearrange(self.visualization_resize(encoder_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy()
+ small_predicted_pixel_values_face = rearrange(self.visualization_resize(predicted_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy()
+ small_pixel_values_face = rearrange(self.visualization_resize(encoder_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy()
+
+ small_predicted_pixel_values = add_text_to_image(((small_predicted_pixel_values * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Images", add_below=False)
+ small_pixel_values = add_text_to_image(((small_pixel_values * 0.5 + 0.5) * 255).astype(np.uint8), "Target Images", add_below=False)
+ small_predicted_pixel_values_face = add_text_to_image(((small_predicted_pixel_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Faces", add_below=False)
+ small_pixel_values_face = add_text_to_image(((small_pixel_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Target Faces", add_below=False)
+
+ final_image = np.hstack([small_predicted_pixel_values, small_pixel_values, small_predicted_pixel_values_face, small_pixel_values_face])
+ for tracker in self.accelerator.trackers:
+ if tracker.name == 'wandb':
+ tracker.log({"IDLoss Visualization": wandb.Image(final_image, caption=f"loss: {loss.cpu().tolist()} timesteps: {timesteps.cpu().tolist()}, valid_indices: {valid_indices.cpu().tolist()}")})
+
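+
+# Minimal sketch of the identity-loss math (feature shapes such as [B, 512] are assumed):
+# for L2-normalized embeddings the dot product equals the cosine similarity, so the
+# per-sample loss is simply 1 - <f(target), f(pred)>.
+def _id_loss_sketch(pred_feats: torch.Tensor, target_feats: torch.Tensor) -> torch.Tensor:
+ pred_feats = F.normalize(pred_feats, dim=-1)
+ target_feats = F.normalize(target_feats, dim=-1)
+ return (1 - torch.einsum("bi,bi->b", target_feats, pred_feats)).mean()
+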
+
+class ImageAugmentations(torch.nn.Module):
+ # Standard image augmentations used for CLIP loss to discourage adversarial outputs.
+ def __init__(self, output_size, augmentations_number, p=0.7):
+ super().__init__()
+ self.output_size = output_size
+ self.augmentations_number = augmentations_number
+
+ self.augmentations = torch.nn.Sequential(
+ K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"), # type: ignore
+ K.RandomPerspective(0.7, p=p),
+ )
+
+ self.avg_pool = torch.nn.AdaptiveAvgPool2d((self.output_size, self.output_size))
+
+ self.device = None
+
+ def forward(self, input):
+ """Extents the input batch with augmentations
+ If the input is consists of images [I1, I2] the extended augmented output
+ will be [I1_resized, I2_resized, I1_aug1, I2_aug1, I1_aug2, I2_aug2 ...]
+ Args:
+ input ([type]): input batch of shape [batch, C, H, W]
+ Returns:
+ updated batch: of shape [batch * augmentations_number, C, H, W]
+ """
+ # We want to multiply the number of images in the batch in contrast to regular augmantations
+ # that do not change the number of samples in the batch)
+ resized_images = self.avg_pool(input)
+ resized_images = torch.tile(resized_images, dims=(self.augmentations_number, 1, 1, 1))
+
+ batch_size = input.shape[0]
+ # We want at least one non augmented image
+ non_augmented_batch = resized_images[:batch_size]
+ augmented_batch = self.augmentations(resized_images[batch_size:])
+ updated_batch = torch.cat([non_augmented_batch, augmented_batch], dim=0)
+
+ return updated_batch
+
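+
+# Minimal usage sketch (the sizes are hypothetical): a batch of 2 images becomes
+# 2 * augmentations_number images at output_size; the first 2 are only resized, the rest augmented.
+def _image_augmentations_sketch() -> torch.Size:
+ augmenter = ImageAugmentations(output_size=224, augmentations_number=4)
+ batch = torch.randn(2, 3, 512, 512)
+ return augmenter(batch).shape # torch.Size([8, 3, 224, 224])
+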
+
+class CLIPLoss(Loss):
+ def __init__(self, augmentations_number: int = 4, **kwargs):
+ super().__init__(**kwargs)
+
+ self.clip_model, clip_preprocess = clip.load("ViT-B/16", device=self.accelerator.device, jit=False)
+
+ self.clip_model.device = None
+
+ self.clip_model.eval().requires_grad_(False)
+
+ self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (SD output) to [0, 1].
+ clip_preprocess.transforms[:2] + # Resize and center-crop to CLIP's input resolution.
+ clip_preprocess.transforms[4:]) # Keep CLIP's normalization; skip the PIL-conversion and ToTensor steps.
+
+ self.clip_size = self.clip_model.visual.input_resolution
+
+ self.clip_normalize = transforms.Normalize(
+ mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
+ )
+
+ self.image_augmentations = ImageAugmentations(output_size=self.clip_size,
+ augmentations_number=augmentations_number)
+
+ self.clip_model, self.image_augmentations = self.accelerator.prepare(self.clip_model, self.image_augmentations)
+
+ def forward(self, decoder_prompts, predicted_pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
+
+ if not isinstance(decoder_prompts, list):
+ decoder_prompts = [decoder_prompts]
+
+ tokens = clip.tokenize(decoder_prompts).to(predicted_pixel_values.device)
+ image = self.preprocess(predicted_pixel_values)
+
+ logits_per_image, _ = self.clip_model(image, tokens)
+
+ logits_per_image = torch.diagonal(logits_per_image)
+
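+ # The diagonal pairs each image with its own prompt. CLIP multiplies cosine similarities by its
+ # learned logit scale (roughly 100 for the released models), so dividing by 100 approximately
+ # recovers the cosine similarity and the loss is roughly 1 - similarity.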
+ return (1. - logits_per_image / 100).mean()
+
+
+class DINOLoss(Loss):
+ def __init__(
+ self,
+ dino_model,
+ dino_preprocess,
+ output_hidden_states: bool = False,
+ center_momentum: float = 0.9,
+ student_temp: float = 0.1,
+ teacher_temp: float = 0.04,
+ warmup_teacher_temp: float = 0.04,
+ warmup_teacher_temp_epochs: int = 30,
+ **kwargs):
+ super().__init__(**kwargs)
+
+ self.dino_model = dino_model
+ self.output_hidden_states = output_hidden_states
+ self.rescale_factor = dino_preprocess.rescale_factor
+
+ # Un-normalize from [-1.0, 1.0] (SD output) to [0, 1].
+ self.preprocess = transforms.Compose(
+ [
+ transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
+ transforms.Resize(size=256),
+ transforms.CenterCrop(size=(224, 224)),
+ transforms.Normalize(mean=dino_preprocess.image_mean, std=dino_preprocess.image_std)
+ ]
+ )
+
+ self.student_temp = student_temp
+ self.teacher_temp = teacher_temp
+ self.center_momentum = center_momentum
+ self.center = torch.zeros(1, 257, 1024).to(self.accelerator.device, dtype=self.dtype)
+
+ # TODO: add temp, now fixed to 0.04
+ # we apply a warm up for the teacher temperature because
+ # too high a temperature makes the training unstable at the beginning
+ # self.teacher_temp_schedule = np.concatenate((
+ # np.linspace(warmup_teacher_temp,
+ # teacher_temp, warmup_teacher_temp_epochs),
+ # np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
+ # ))
+
+ self.dino_model = self.accelerator.prepare(self.dino_model)
+
+ def forward(
+ self,
+ target: torch.Tensor,
+ predict: torch.Tensor,
+ weights: torch.Tensor = None,
+ **kwargs) -> torch.Tensor:
+
+ predict = self.preprocess(predict)
+ target = self.preprocess(target)
+
+ encoder_input = torch.cat([target, predict]).to(self.dino_model.device, dtype=self.dino_model.dtype)
+
+ if self.output_hidden_states:
+ raise ValueError("Output hidden states not supported for DINO loss.")
+ image_enc_hidden_states = self.dino_model(encoder_input, output_hidden_states=True).hidden_states[-2]
+ else:
+ image_enc_hidden_states = self.dino_model(encoder_input).last_hidden_state
+
+ teacher_output, student_output = image_enc_hidden_states.chunk(2, dim=0) # [B, 257, 1024]
+
+ student_out = student_output.float() / self.student_temp
+
+ # teacher centering and sharpening
+ # temp = self.teacher_temp_schedule[epoch]
+ temp = self.teacher_temp
+ teacher_out = F.softmax((teacher_output.float() - self.center) / temp, dim=-1)
+ teacher_out = teacher_out.detach()
+
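+ # Cross-entropy H(teacher, student) = -sum_d p_teacher(d) * log p_student(d), computed per token over the feature dimension.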
+ loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1, keepdim=True)
+ # self.update_center(teacher_output)
+
+ if weights is not None:
+ loss = loss * weights
+ return loss.mean()
+
+ @torch.no_grad()
+ def update_center(self, teacher_output):
+ """
+ Update center used for teacher output.
+ """
+ batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
+ self.accelerator.reduce(batch_center, reduction="sum")
+ batch_center = batch_center / (len(teacher_output) * self.accelerator.num_processes)
+
+ # ema update
+ self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
diff --git a/module/aggregator.py b/module/aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd615100376e019a2fd6dd6a94d176f6047c1762
--- /dev/null
+++ b/module/aggregator.py
@@ -0,0 +1,983 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders.single_file_model import FromOriginalModelMixin
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.unets.unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ DownBlock2D,
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+)
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class ZeroConv(nn.Module):
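+ # Fuses the condition features c and the hidden features h with a zero-initialized 1x1 conv,
+ # so the module outputs zeros at initialization and only gradually learns to contribute.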
+ def __init__(self, label_nc, norm_nc, mask=False):
+ super().__init__()
+ self.zero_conv = zero_module(nn.Conv2d(label_nc+norm_nc, norm_nc, 1, 1, 0))
+ self.mask = mask
+
+ def forward(self, hidden_states, h_ori=None):
+ # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
+ c, h = hidden_states
+ if not self.mask:
+ h = self.zero_conv(torch.cat([c, h], dim=1))
+ else:
+ h = self.zero_conv(torch.cat([c, h], dim=1)) * torch.zeros_like(h)
+ if h_ori is not None:
+ h = torch.cat([h_ori, h], dim=1)
+ return h
+
+
+class SFT(nn.Module):
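+ # Spatial Feature Transform: a small shared conv predicts a per-pixel scale (gamma) and shift (beta)
+ # from the condition features and modulates the hidden features as h * (1 + gamma) + beta.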
+ def __init__(self, label_nc, norm_nc, mask=False):
+ super().__init__()
+
+ # param_free_norm_type = str(parsed.group(1))
+ ks = 3
+ pw = ks // 2
+
+ self.mask = mask
+
+ nhidden = 128
+
+ self.mlp_shared = nn.Sequential(
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
+ nn.SiLU()
+ )
+ self.mul = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
+ self.add = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
+
+ def forward(self, hidden_states, mask=False):
+
+ c, h = hidden_states
+ mask = mask or self.mask
+ assert mask is False
+
+ actv = self.mlp_shared(c)
+ gamma = self.mul(actv)
+ beta = self.add(actv)
+
+ if self.mask:
+ gamma = gamma * torch.zeros_like(gamma)
+ beta = beta * torch.zeros_like(beta)
+ # gamma_ori, gamma_res = torch.split(gamma, [h_ori_c, h_c], dim=1)
+ # beta_ori, beta_res = torch.split(beta, [h_ori_c, h_c], dim=1)
+ # print(gamma_ori.mean(), gamma_res.mean(), beta_ori.mean(), beta_res.mean())
+ h = h * (gamma + 1) + beta
+ # sample_ori, sample_res = torch.split(h, [h_ori_c, h_c], dim=1)
+ # print(sample_ori.mean(), sample_res.mean())
+
+ return h
+
+
+@dataclass
+class AggregatorOutput(BaseOutput):
+ """
+ The output of [`Aggregator`].
+
+ Args:
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+ mid_block_res_sample (`torch.Tensor`):
+ The activation of the middle block (the lowest sample resolution). The tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class ConditioningEmbedding(nn.Module):
+ """
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 'latent images' for stabilized
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+ model) to encode image-space conditions ... into feature maps ..."
+ """
+
+ def __init__(
+ self,
+ conditioning_embedding_channels: int,
+ conditioning_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(block_out_channels) - 1):
+ channel_in = block_out_channels[i]
+ channel_out = block_out_channels[i + 1]
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+ self.conv_out = zero_module(
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+ )
+
+ def forward(self, conditioning):
+ embedding = self.conv_in(conditioning)
+ embedding = F.silu(embedding)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+ embedding = F.silu(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
+
+
+class Aggregator(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ """
+ Aggregator model: a ControlNet-style module that encodes a reference latent together with the
+ denoising latent and returns per-resolution residual features used to condition the base UNet.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ TODO(Patrick) - unused parameter.
+ addition_embed_type_num_heads (`int`, defaults to 64):
+ The number of heads to use for the `TextTimeEmbedding` layer.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 3,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads: int = 64,
+ pad_concat: bool = False,
+ ):
+ super().__init__()
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+ self.pad_concat = pad_concat
+
+ # Check inputs
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ # control net conditioning embedding
+ self.ref_conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
+ # controlnet_block = ZeroConv(output_channel, output_channel)
+ controlnet_block = nn.Sequential(
+ SFT(output_channel, output_channel),
+ zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1))
+ )
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ # controlnet_block = ZeroConv(output_channel, output_channel)
+ controlnet_block = nn.Sequential(
+ SFT(output_channel, output_channel),
+ zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1))
+ )
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ # controlnet_block = ZeroConv(output_channel, output_channel)
+ controlnet_block = nn.Sequential(
+ SFT(output_channel, output_channel),
+ zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1))
+ )
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ # controlnet_block = ZeroConv(mid_block_channel, mid_block_channel)
+ controlnet_block = nn.Sequential(
+ SFT(mid_block_channel, mid_block_channel),
+ zero_module(nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1))
+ )
+ self.controlnet_mid_block = controlnet_block
+
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ num_layers=0,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_groups=norm_num_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ controlnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 3,
+ ):
+ r"""
+ Instantiate an [`Aggregator`] from a [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
+ where applicable.
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+
+ controlnet = cls(
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=unet.config.in_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ down_block_types=unet.config.down_block_types,
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ class_embed_type=unet.config.class_embed_type,
+ num_class_embeds=unet.config.num_class_embeds,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ mid_block_type=unet.config.mid_block_type,
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ conditioning_channels=conditioning_channels,
+ )
+
+ if load_weights_from_unet:
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.ref_conv_in.load_state_dict(unet.conv_in.state_dict())
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ if controlnet.class_embedding:
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+ if hasattr(controlnet, "add_embedding"):
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
+
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
+
+ return controlnet
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def process_encoder_hidden_states(
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> torch.Tensor:
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ image_embeds = self.encoder_hid_proj(image_embeds)
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
+ return encoder_hidden_states
+
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ controlnet_cond: torch.FloatTensor,
+ cat_dim: int = -2,
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[AggregatorOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
+ """
+ The [`Aggregator`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ controlnet_cond (`torch.FloatTensor`):
+ The reference latent used as the conditional input; it has shape `(batch_size, in_channels, height, width)` and is encoded by `ref_conv_in`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return an [`AggregatorOutput`] instead of a plain tuple.
+
+ Returns:
+ [`AggregatorOutput`] **or** `tuple`:
+ If `return_dict` is `True`, an [`AggregatorOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
+ # check channel order
+ channel_order = self.config.controlnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ else:
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type is not None:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ encoder_hidden_states = self.process_encoder_hidden_states(
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+
+ # 2. prepare input
+ cond_latent = self.conv_in(sample)
+ ref_latent = self.ref_conv_in(controlnet_cond)
+ batch_size, channel, height, width = cond_latent.shape
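+ # The noisy latent and the reference latent are stacked along a spatial dimension so both are
+ # processed in a single encoder pass; with pad_concat a one-pixel zero seam keeps the two halves
+ # from blending across the border.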
+ if self.pad_concat:
+ if cat_dim == -2 or cat_dim == 2:
+ concat_pad = torch.zeros(batch_size, channel, 1, width)
+ elif cat_dim == -1 or cat_dim == 3:
+ concat_pad = torch.zeros(batch_size, channel, height, 1)
+ else:
+ raise ValueError(f"Aggregator shall concat along spatial dimension, but is asked to concat dim: {cat_dim}.")
+ concat_pad = concat_pad.to(cond_latent.device, dtype=cond_latent.dtype)
+ sample = torch.cat([cond_latent, concat_pad, ref_latent], dim=cat_dim)
+ else:
+ sample = torch.cat([cond_latent, ref_latent], dim=cat_dim)
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ # rebuild sample: split and concat
+ if self.pad_concat:
+ batch_size, channel, height, width = sample.shape
+ if cat_dim == -2 or cat_dim == 2:
+ cond_latent = sample[:, :, :height//2, :]
+ ref_latent = sample[:, :, -(height//2):, :]
+ concat_pad = torch.zeros(batch_size, channel, 1, width)
+ elif cat_dim == -1 or cat_dim == 3:
+ cond_latent = sample[:, :, :, :width//2]
+ ref_latent = sample[:, :, :, -(width//2):]
+ concat_pad = torch.zeros(batch_size, channel, height, 1)
+ concat_pad = concat_pad.to(cond_latent.device, dtype=cond_latent.dtype)
+ sample = torch.cat([cond_latent, concat_pad, ref_latent], dim=cat_dim)
+ res_samples = res_samples[:-1] + (sample,)
+
+ down_block_res_samples += res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ sample = self.mid_block(
+ sample,
+ emb,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+
+ # 5. split samples and SFT.
+ controlnet_down_block_res_samples = ()
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ batch_size, channel, height, width = down_block_res_sample.shape
+ if cat_dim == -2 or cat_dim == 2:
+ cond_latent = down_block_res_sample[:, :, :height//2, :]
+ ref_latent = down_block_res_sample[:, :, -(height//2):, :]
+ elif cat_dim == -1 or cat_dim == 3:
+ cond_latent = down_block_res_sample[:, :, :, :width//2]
+ ref_latent = down_block_res_sample[:, :, :, -(width//2):]
+ down_block_res_sample = controlnet_block((cond_latent, ref_latent), )
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ batch_size, channel, height, width = sample.shape
+ if cat_dim == -2 or cat_dim == 2:
+ cond_latent = sample[:, :, :height//2, :]
+ ref_latent = sample[:, :, -(height//2):, :]
+ elif cat_dim == -1 or cat_dim == 3:
+ cond_latent = sample[:, :, :, :width//2]
+ ref_latent = sample[:, :, :, -(width//2):]
+ mid_block_res_sample = self.controlnet_mid_block((cond_latent, ref_latent), )
+
+ # 6. scaling
+ down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * conditioning_scale
+
+ if self.config.global_pool_conditions:
+ down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
+ ]
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
+
+ if not return_dict:
+ return (down_block_res_samples, mid_block_res_sample)
+
+ return AggregatorOutput(
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
+ )
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
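+
+# Usage sketch (illustrative, not part of the original module): `zero_module` is the
+# ControlNet-style initialization trick of zero-initializing a projection so that a
+# newly added conditioning branch starts as a no-op and only gradually learns to
+# contribute. `nn` here is `torch.nn` as imported above.
+#
+# proj_out = zero_module(nn.Conv2d(320, 320, kernel_size=1))
+# assert all(torch.all(p == 0) for p in proj_out.parameters())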
diff --git a/module/attention.py b/module/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..12b5f2539b9e7f86a264fa7eb7c2fc3680379d28
--- /dev/null
+++ b/module/attention.py
@@ -0,0 +1,259 @@
+# Copied from diffusers.models.attention.py
+
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate, logging
+from diffusers.utils.torch_utils import maybe_allow_in_graph
+from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
+from diffusers.models.attention_processor import Attention
+from diffusers.models.embeddings import SinusoidalPositionalEmbedding
+from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
+
+from module.min_sdxl import LoRACompatibleLinear, LoRALinearLayer
+
+
+logger = logging.get_logger(__name__)
+
+def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+def maybe_grad_checkpoint(resnet, attn, hidden_states, temb, encoder_hidden_states, adapter_hidden_states, do_ckpt=True):
+
+ if do_ckpt:
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states, extracted_kv = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn), hidden_states, encoder_hidden_states, adapter_hidden_states, use_reentrant=False
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states, extracted_kv = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ adapter_hidden_states=adapter_hidden_states,
+ )
+ return hidden_states, extracted_kv
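+
+# Usage sketch (illustrative): inside a block that pairs a resnet with an attention
+# module, activation checkpointing can be toggled per call. The attributes
+# `self.training` and `self.gradient_checkpointing` are assumed to exist on the caller.
+#
+# hidden_states, extracted_kv = maybe_grad_checkpoint(
+#     resnet, attn, hidden_states, temb, encoder_hidden_states, adapter_hidden_states,
+#     do_ckpt=self.training and self.gradient_checkpointing,
+# )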
+
+
+def init_lora_in_attn(attn_module, rank: int = 4, is_kvcopy=False):
+ # Set the `lora_layer` attribute of the attention-related matrices.
+
+ attn_module.to_k.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=rank
+ )
+ )
+ attn_module.to_v.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=rank
+ )
+ )
+
+ if not is_kvcopy:
+ attn_module.to_q.set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=rank
+ )
+ )
+
+ attn_module.to_out[0].set_lora_layer(
+ LoRALinearLayer(
+ in_features=attn_module.to_out[0].in_features,
+ out_features=attn_module.to_out[0].out_features,
+ rank=rank,
+ )
+ )
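+
+# Usage sketch (illustrative): attach rank-4 LoRA layers to every attention module of
+# a UNet. `unet` is an assumed diffusers UNet whose projection layers are
+# LoRA-compatible linear layers exposing `set_lora_layer` (as in `module.min_sdxl`
+# imported above).
+#
+# for m in unet.modules():
+#     if isinstance(m, Attention):
+#         init_lora_in_attn(m, rank=4, is_kvcopy=False)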
+
+def drop_kvs(encoder_kvs, drop_chance):
+ for layer in encoder_kvs:
+ len_tokens = encoder_kvs[layer].self_attention.k.shape[1]
+ idx_to_keep = (torch.rand(len_tokens) > drop_chance)
+
+ encoder_kvs[layer].self_attention.k = encoder_kvs[layer].self_attention.k[:, idx_to_keep]
+ encoder_kvs[layer].self_attention.v = encoder_kvs[layer].self_attention.v[:, idx_to_keep]
+
+ return encoder_kvs
+
+def clone_kvs(encoder_kvs):
+ cloned_kvs = {}
+ for layer in encoder_kvs:
+ sa_cpy = KVCache(k=encoder_kvs[layer].self_attention.k.clone(),
+ v=encoder_kvs[layer].self_attention.v.clone())
+
+ ca_cpy = KVCache(k=encoder_kvs[layer].cross_attention.k.clone(),
+ v=encoder_kvs[layer].cross_attention.v.clone())
+
+ cloned_layer_cache = AttentionCache(self_attention=sa_cpy, cross_attention=ca_cpy)
+
+ cloned_kvs[layer] = cloned_layer_cache
+
+ return cloned_kvs
+
+
+class KVCache(object):
+ def __init__(self, k, v):
+ self.k = k
+ self.v = v
+
+class AttentionCache(object):
+ def __init__(self, self_attention: KVCache, cross_attention: KVCache):
+ self.self_attention = self_attention
+ self.cross_attention = cross_attention
+
+class KVCopy(nn.Module):
+ def __init__(
+ self, inner_dim, cross_attention_dim=None,
+ ):
+ super(KVCopy, self).__init__()
+
+ in_dim = cross_attention_dim or inner_dim
+
+ self.to_k = LoRACompatibleLinear(in_dim, inner_dim, bias=False)
+ self.to_v = LoRACompatibleLinear(in_dim, inner_dim, bias=False)
+
+ def forward(self, hidden_states):
+
+ k = self.to_k(hidden_states)
+ v = self.to_v(hidden_states)
+
+ return KVCache(k=k, v=v)
+
+ def init_kv_copy(self, source_attn):
+ with torch.no_grad():
+ self.to_k.weight.copy_(source_attn.to_k.weight)
+ self.to_v.weight.copy_(source_attn.to_v.weight)
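+
+# Usage sketch (illustrative): create a key/value-only copy of an existing attention
+# module, initialize it from that module's weights, and cache projected K/V for reuse.
+# `source_attn` and `hidden_states` are assumed to exist in the calling code.
+#
+# kv_copy = KVCopy(inner_dim=source_attn.to_k.out_features,
+#                  cross_attention_dim=source_attn.to_k.in_features)
+# kv_copy.init_kv_copy(source_attn)
+# cache = kv_copy(hidden_states)  # KVCache with fields .k and .v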
+
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ final_dropout (`bool`, *optional*, defaults to False): Apply a final dropout.
+ inner_dim (`int`, *optional*): If given, overrides the hidden dimension otherwise computed as `dim * mult`.
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ final_dropout: bool = False,
+ inner_dim=None,
+ bias: bool = True,
+ ):
+ super().__init__()
+ if inner_dim is None:
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+ if activation_fn == "gelu":
+ act_fn = GELU(dim, inner_dim, bias=bias)
+ if activation_fn == "gelu-approximate":
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
+ elif activation_fn == "geglu":
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
+ elif activation_fn == "geglu-approximate":
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
+ if final_dropout:
+ self.net.append(nn.Dropout(dropout))
+
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
+ # "feed_forward_chunk_size" can be used to save memory
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
+ raise ValueError(
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+ )
+
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+ ff_output = torch.cat(
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+ dim=chunk_dim,
+ )
+ return ff_output
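+
+# Usage sketch (illustrative): chunking the feed-forward along the token dimension
+# trades a little speed for a lower peak in activation memory. The shapes below are
+# arbitrary example values.
+#
+# ff = FeedForward(dim=320, activation_fn="geglu")
+# hidden_states = torch.randn(2, 4096, 320)
+# out = _chunked_feed_forward(ff, hidden_states, chunk_dim=1, chunk_size=1024)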
+
+
+@maybe_allow_in_graph
+class GatedSelfAttentionDense(nn.Module):
+ r"""
+ A gated self-attention dense layer that combines visual features and object features.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ context_dim (`int`): The number of channels in the context.
+ n_heads (`int`): The number of heads to use for attention.
+ d_head (`int`): The number of channels in each head.
+ """
+
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
+ super().__init__()
+
+ # we need a linear projection since we need cat visual feature and obj feature
+ self.linear = nn.Linear(context_dim, query_dim)
+
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
+
+ self.norm1 = nn.LayerNorm(query_dim)
+ self.norm2 = nn.LayerNorm(query_dim)
+
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
+
+ self.enabled = True
+
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
+ if not self.enabled:
+ return x
+
+ n_visual = x.shape[1]
+ objs = self.linear(objs)
+
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
+
+ return x
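+
+# Usage sketch (illustrative, GLIGEN-style fusion): `x` holds visual tokens of shape
+# (batch, n_visual, query_dim) and `objs` holds grounding/object tokens of shape
+# (batch, n_objs, context_dim); only the visual tokens are returned, gated by the
+# learnable `alpha_attn` / `alpha_dense` scalars.
+#
+# fuser = GatedSelfAttentionDense(query_dim=320, context_dim=768, n_heads=8, d_head=40)
+# x = fuser(torch.randn(2, 64, 320), torch.randn(2, 30, 768))  # -> (2, 64, 320)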
diff --git a/module/diffusers_vae/autoencoder_kl.py b/module/diffusers_vae/autoencoder_kl.py
new file mode 100644
index 0000000000000000000000000000000000000000..60e897e59df853491e1fb07d06a9c21fcafd8cac
--- /dev/null
+++ b/module/diffusers_vae/autoencoder_kl.py
@@ -0,0 +1,489 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import FromOriginalVAEMixin
+from diffusers.utils.accelerate_utils import apply_forward_hook
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.modeling_outputs import AutoencoderKLOutput
+from diffusers.models.modeling_utils import ModelMixin
+from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+
+
+class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
+ r"""
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+ Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ force_upcast (`bool`, *optional*, default to `True`):
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
+ `force_upcast` can be set to `False` - see: https://huggingface.co./madebyollin/sdxl-vae-fp16-fix
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ norm_num_groups: int = 32,
+ sample_size: int = 32,
+ scaling_factor: float = 0.18215,
+ force_upcast: bool = True,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ )
+
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+
+ self.use_slicing = False
+ self.use_tiling = False
+
+ # only relevant if vae tiling is enabled
+ self.tile_sample_min_size = self.config.sample_size
+ sample_size = (
+ self.config.sample_size[0]
+ if isinstance(self.config.sample_size, (list, tuple))
+ else self.config.sample_size
+ )
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor = 0.25
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (Encoder, Decoder)):
+ module.gradient_checkpointing = value
+
+ def enable_tiling(self, use_tiling: bool = True):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+ processing larger images.
+ """
+ self.use_tiling = use_tiling
+
+ def disable_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.enable_tiling(False)
+
+ def enable_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
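+
+ # Usage sketch (illustrative): on memory-constrained GPUs both savers can be enabled
+ # before calling `encode`/`decode`; `vae` is an assumed instance of this class.
+ # vae.enable_tiling()   # spatial tiles with blended overlaps
+ # vae.enable_slicing()  # process the batch one sample at a time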
+
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.FloatTensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ """
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded images. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+ return self.tiled_encode(x, return_dict=return_dict)
+
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self.encoder(x)
+
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
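+
+ # Usage sketch (illustrative): the typical latent-diffusion round trip, with `vae` an
+ # assumed instance of this class, `x` an image batch in [-1, 1], and `scaling_factor`
+ # taken from the config documented above.
+ #
+ # latents = vae.encode(x).latent_dist.sample() * vae.config.scaling_factor
+ # images = vae.decode(latents / vae.config.scaling_factor).sample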
+
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ """
+ Decode a batch of images.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+ return b
+
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ r"""Encode a batch of images using a tiled encoder.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding
+ differs from non-tiled encoding because each tile is encoded independently of the others. To avoid tiling artifacts, the
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
+ output, but they should be much less noticeable.
+
+ Args:
+ x (`torch.FloatTensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
+ `tuple` is returned.
+ """
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ # Split the image into 512x512 tiles and encode them separately.
+ rows = []
+ for i in range(0, x.shape[2], overlap_size):
+ row = []
+ for j in range(0, x.shape[3], overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ moments = torch.cat(result_rows, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
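+
+ # Tiling arithmetic sketch (illustrative, assuming a typical SD-style config with
+ # sample_size=512 and four blocks, i.e. an 8x latent downscale): tiles start every
+ # overlap_size = int(512 * 0.75) = 384 pixels, tile_latent_min_size = 64,
+ # blend_extent = int(64 * 0.25) = 16, and row_limit = 48 latent rows/columns are
+ # kept per tile after blending.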
+
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.FloatTensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ # Split z into overlapping 64x64 tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ rows = []
+ for i in range(0, z.shape[2], overlap_size):
+ row = []
+ for j in range(0, z.shape[3], overlap_size):
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ dec = torch.cat(result_rows, dim=2)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
\ No newline at end of file
diff --git a/module/diffusers_vae/vae.py b/module/diffusers_vae/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..1315f9dba257555497564e0ce334f6c3d6ac3933
--- /dev/null
+++ b/module/diffusers_vae/vae.py
@@ -0,0 +1,985 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from diffusers.utils import BaseOutput, is_torch_version
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import SpatialNorm
+from diffusers.models.unet_2d_blocks import (
+ AutoencoderTinyBlock,
+ UNetMidBlock2D,
+ get_down_block,
+ get_up_block,
+)
+
+
+@dataclass
+class DecoderOutput(BaseOutput):
+ r"""
+ Output of decoding method.
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The decoded output sample from the last layer of the model.
+ """
+
+ sample: torch.FloatTensor
+
+
+class Encoder(nn.Module):
+ r"""
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 3):
+ The number of output channels.
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
+ options.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+ The number of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups for normalization.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ double_z (`bool`, *optional*, defaults to `True`):
+ Whether to double the number of output channels for the last block.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+ block_out_channels: Tuple[int, ...] = (64,),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ double_z: bool = True,
+ mid_block_add_attention=True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[0],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=self.layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ add_downsample=not is_final_block,
+ resnet_eps=1e-6,
+ downsample_padding=0,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=output_channel,
+ temb_channels=None,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default",
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=None,
+ add_attention=mid_block_add_attention,
+ )
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+
+ conv_out_channels = 2 * out_channels if double_z else out_channels
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ r"""The forward method of the `Encoder` class."""
+
+ sample = self.conv_in(sample)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ # down
+ if is_torch_version(">=", "1.11.0"):
+ for down_block in self.down_blocks:
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(down_block), sample, use_reentrant=False
+ )
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), sample, use_reentrant=False
+ )
+ else:
+ for down_block in self.down_blocks:
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
+
+ else:
+ # down
+ for down_block in self.down_blocks:
+ sample = down_block(sample)
+
+ # middle
+ sample = self.mid_block(sample)
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class Decoder(nn.Module):
+ r"""
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 3):
+ The number of output channels.
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+ The number of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups for normalization.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ norm_type (`str`, *optional*, defaults to `"group"`):
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int, ...] = (64,),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ norm_type: str = "group", # group, spatial
+ mid_block_add_attention=True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[-1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ temb_channels = in_channels if norm_type == "spatial" else None
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=temb_channels,
+ add_attention=mid_block_add_attention,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=not is_final_block,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=output_channel,
+ temb_channels=temb_channels,
+ resnet_time_scale_shift=norm_type,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_type == "spatial":
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
+ else:
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ latent_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ r"""The forward method of the `Decoder` class."""
+
+ sample = self.conv_in(sample)
+ sample = sample.to(torch.float32)
+
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
+ )
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
+ )
+ else:
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), sample, latent_embeds
+ )
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
+ else:
+ # middle
+ sample = self.mid_block(sample, latent_embeds)
+ sample = sample.to(upscale_dtype)
+
+ # up
+ for up_block in self.up_blocks:
+ sample = up_block(sample, latent_embeds)
+
+ # post-process
+ if latent_embeds is None:
+ sample = self.conv_norm_out(sample)
+ else:
+ sample = self.conv_norm_out(sample, latent_embeds)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class UpSample(nn.Module):
+ r"""
+ The `UpSample` layer of a variational autoencoder that upsamples its input.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 3):
+ The number of output channels.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ r"""The forward method of the `UpSample` class."""
+ x = torch.relu(x)
+ x = self.deconv(x)
+ return x
+
+
+class MaskConditionEncoder(nn.Module):
+ """
+ used in AsymmetricAutoencoderKL
+ """
+
+ def __init__(
+ self,
+ in_ch: int,
+ out_ch: int = 192,
+ res_ch: int = 768,
+ stride: int = 16,
+ ) -> None:
+ super().__init__()
+
+ channels = []
+ while stride > 1:
+ stride = stride // 2
+ in_ch_ = out_ch * 2
+ if out_ch > res_ch:
+ out_ch = res_ch
+ if stride == 1:
+ in_ch_ = res_ch
+ channels.append((in_ch_, out_ch))
+ out_ch *= 2
+
+ out_channels = []
+ for _in_ch, _out_ch in channels:
+ out_channels.append(_out_ch)
+ out_channels.append(channels[-1][0])
+
+ layers = []
+ in_ch_ = in_ch
+ for l in range(len(out_channels)):
+ out_ch_ = out_channels[l]
+ if l == 0 or l == 1:
+ layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
+ else:
+ layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
+ in_ch_ = out_ch_
+
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
+ r"""The forward method of the `MaskConditionEncoder` class."""
+ out = {}
+ for l in range(len(self.layers)):
+ layer = self.layers[l]
+ x = layer(x)
+ out[str(tuple(x.shape))] = x
+ x = torch.relu(x)
+ return out
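+
+# Note (illustrative): the encoder returns a dict of intermediate feature maps keyed by
+# the string of their shape, e.g. out[str(tuple(sample.shape))]; `MaskConditionDecoder`
+# below looks features up by the current sample's shape while upsampling.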
+
+
+class MaskConditionDecoder(nn.Module):
+ r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
+ decoder with a conditioner on the mask and masked image.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 3):
+ The number of input channels.
+ out_channels (`int`, *optional*, defaults to 3):
+ The number of output channels.
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
+ The number of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2):
+ The number of layers per block.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups for normalization.
+ act_fn (`str`, *optional*, defaults to `"silu"`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ norm_type (`str`, *optional*, defaults to `"group"`):
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int, ...] = (64,),
+ layers_per_block: int = 2,
+ norm_num_groups: int = 32,
+ act_fn: str = "silu",
+ norm_type: str = "group", # group, spatial
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+
+ self.conv_in = nn.Conv2d(
+ in_channels,
+ block_out_channels[-1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ )
+
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ temb_channels = in_channels if norm_type == "spatial" else None
+
+ # mid
+ self.mid_block = UNetMidBlock2D(
+ in_channels=block_out_channels[-1],
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ output_scale_factor=1,
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
+ attention_head_dim=block_out_channels[-1],
+ resnet_groups=norm_num_groups,
+ temb_channels=temb_channels,
+ )
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+
+ is_final_block = i == len(block_out_channels) - 1
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=self.layers_per_block + 1,
+ in_channels=prev_output_channel,
+ out_channels=output_channel,
+ prev_output_channel=None,
+ add_upsample=not is_final_block,
+ resnet_eps=1e-6,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ attention_head_dim=output_channel,
+ temb_channels=temb_channels,
+ resnet_time_scale_shift=norm_type,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # condition encoder
+ self.condition_encoder = MaskConditionEncoder(
+ in_ch=out_channels,
+ out_ch=block_out_channels[0],
+ res_ch=block_out_channels[-1],
+ )
+
+ # out
+ if norm_type == "spatial":
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
+ else:
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ z: torch.FloatTensor,
+ image: Optional[torch.FloatTensor] = None,
+ mask: Optional[torch.FloatTensor] = None,
+ latent_embeds: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ r"""The forward method of the `MaskConditionDecoder` class."""
+ sample = z
+ sample = self.conv_in(sample)
+
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
+ )
+ sample = sample.to(upscale_dtype)
+
+ # condition encoder
+ if image is not None and mask is not None:
+ masked_image = (1 - mask) * image
+ im_x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.condition_encoder),
+ masked_image,
+ mask,
+ use_reentrant=False,
+ )
+
+ # up
+ for up_block in self.up_blocks:
+ if image is not None and mask is not None:
+ sample_ = im_x[str(tuple(sample.shape))]
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
+ sample = sample * mask_ + sample_ * (1 - mask_)
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(up_block),
+ sample,
+ latent_embeds,
+ use_reentrant=False,
+ )
+ if image is not None and mask is not None:
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
+ else:
+ # middle
+ sample = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.mid_block), sample, latent_embeds
+ )
+ sample = sample.to(upscale_dtype)
+
+ # condition encoder
+ if image is not None and mask is not None:
+ masked_image = (1 - mask) * image
+ im_x = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(self.condition_encoder),
+ masked_image,
+ mask,
+ )
+
+ # up
+ for up_block in self.up_blocks:
+ if image is not None and mask is not None:
+ sample_ = im_x[str(tuple(sample.shape))]
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
+ sample = sample * mask_ + sample_ * (1 - mask_)
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
+ if image is not None and mask is not None:
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
+ else:
+ # middle
+ sample = self.mid_block(sample, latent_embeds)
+ sample = sample.to(upscale_dtype)
+
+ # condition encoder
+ if image is not None and mask is not None:
+ masked_image = (1 - mask) * image
+ im_x = self.condition_encoder(masked_image, mask)
+
+ # up
+ for up_block in self.up_blocks:
+ if image is not None and mask is not None:
+ sample_ = im_x[str(tuple(sample.shape))]
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
+ sample = sample * mask_ + sample_ * (1 - mask_)
+ sample = up_block(sample, latent_embeds)
+ if image is not None and mask is not None:
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
+
+ # post-process
+ if latent_embeds is None:
+ sample = self.conv_norm_out(sample)
+ else:
+ sample = self.conv_norm_out(sample, latent_embeds)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return sample
+
+
+class VectorQuantizer(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
+ multiplications and allows for post-hoc remapping of indices.
+ """
+
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(
+ self,
+ n_e: int,
+ vq_embed_dim: int,
+ beta: float,
+ remap=None,
+ unknown_index: str = "random",
+ sane_index_shape: bool = False,
+ legacy: bool = True,
+ ):
+ super().__init__()
+ self.n_e = n_e
+ self.vq_embed_dim = vq_embed_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.used: torch.Tensor
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed + 1
+ print(
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices."
+ )
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ match = (inds[:, :, None] == used[None, None, ...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2) < 1
+ if self.unknown_index == "random":
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
+ ishape = inds.shape
+ assert len(ishape) > 1
+ inds = inds.reshape(ishape[0], -1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.vq_embed_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
+
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q: torch.FloatTensor = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0], -1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q: torch.FloatTensor = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
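+
+# Usage sketch (illustrative): quantize a (B, C, H, W) latent with a 256-entry codebook
+# whose embedding width matches the channel count; the straight-through estimator in
+# `forward` keeps gradients flowing back to the encoder.
+#
+# vq = VectorQuantizer(n_e=256, vq_embed_dim=4, beta=0.25)
+# z_q, commit_loss, (_, _, indices) = vq(torch.randn(2, 4, 32, 32))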
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+ )
+
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
+ # make sure sample is on the same device as the parameters and has same dtype
+ sample = randn_tensor(
+ self.mean.shape,
+ generator=generator,
+ device=self.parameters.device,
+ dtype=self.parameters.dtype,
+ )
+ x = self.mean + self.std * sample
+ return x
+
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3],
+ )
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var
+ - 1.0
+ - self.logvar
+ + other.logvar,
+ dim=[1, 2, 3],
+ )
+
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
+ if self.deterministic:
+ return torch.Tensor([0.0])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims,
+ )
+
+ def mode(self) -> torch.Tensor:
+ return self.mean
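+
+# Usage sketch (illustrative): `moments` is the encoder output with mean and
+# log-variance stacked along the channel axis (2 * latent_channels channels).
+#
+# posterior = DiagonalGaussianDistribution(moments)
+# z = posterior.sample()   # reparameterized draw: mean + std * eps
+# kl = posterior.kl()      # KL divergence to a standard normal, one value per sample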
+
+
+class EncoderTiny(nn.Module):
+ r"""
+ The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
+
+ Args:
+ in_channels (`int`):
+ The number of input channels.
+ out_channels (`int`):
+ The number of output channels.
+ num_blocks (`Tuple[int, ...]`):
+ Each value of the tuple represents a Conv2d layer followed by that many `AutoencoderTinyBlock` layers.
+ block_out_channels (`Tuple[int, ...]`):
+ The number of output channels for each block.
+ act_fn (`str`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_blocks: Tuple[int, ...],
+ block_out_channels: Tuple[int, ...],
+ act_fn: str,
+ ):
+ super().__init__()
+
+ layers = []
+ for i, num_block in enumerate(num_blocks):
+ num_channels = block_out_channels[i]
+
+ if i == 0:
+ layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
+ else:
+ layers.append(
+ nn.Conv2d(
+ num_channels,
+ num_channels,
+ kernel_size=3,
+ padding=1,
+ stride=2,
+ bias=False,
+ )
+ )
+
+ for _ in range(num_block):
+ layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
+
+ layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
+
+ self.layers = nn.Sequential(*layers)
+ self.gradient_checkpointing = False
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ r"""The forward method of the `EncoderTiny` class."""
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
+ else:
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+
+ else:
+ # scale image from [-1, 1] to [0, 1] to match TAESD convention
+ x = self.layers(x.add(1).div(2))
+
+ return x
+
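+# Editor's note: an illustrative configuration sketch; the concrete values below are assumptions in
+# the spirit of TAESD, not taken from this repository. Each entry of `num_blocks` contributes a conv
+# (stride-2 after the first stage) followed by that many `AutoencoderTinyBlock`s:
+#
+#   enc = EncoderTiny(
+#       in_channels=3, out_channels=4,
+#       num_blocks=(1, 3, 3, 3),
+#       block_out_channels=(64, 64, 64, 64),
+#       act_fn="relu",
+#   )
+#   latents = enc(images)  # images in [-1, 1] are shifted to [0, 1] (TAESD convention) before the convs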
+
+class DecoderTiny(nn.Module):
+ r"""
+ The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
+
+ Args:
+ in_channels (`int`):
+ The number of input channels.
+ out_channels (`int`):
+ The number of output channels.
+ num_blocks (`Tuple[int, ...]`):
+ Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
+ use.
+ block_out_channels (`Tuple[int, ...]`):
+ The number of output channels for each block.
+ upsampling_scaling_factor (`int`):
+ The scaling factor to use for upsampling.
+ act_fn (`str`):
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_blocks: Tuple[int, ...],
+ block_out_channels: Tuple[int, ...],
+ upsampling_scaling_factor: int,
+ act_fn: str,
+ ):
+ super().__init__()
+
+ layers = [
+ nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
+ get_activation(act_fn),
+ ]
+
+ for i, num_block in enumerate(num_blocks):
+ is_final_block = i == (len(num_blocks) - 1)
+ num_channels = block_out_channels[i]
+
+ for _ in range(num_block):
+ layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
+
+ if not is_final_block:
+ layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
+
+ conv_out_channel = num_channels if not is_final_block else out_channels
+ layers.append(
+ nn.Conv2d(
+ num_channels,
+ conv_out_channel,
+ kernel_size=3,
+ padding=1,
+ bias=is_final_block,
+ )
+ )
+
+ self.layers = nn.Sequential(*layers)
+ self.gradient_checkpointing = False
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ r"""The forward method of the `DecoderTiny` class."""
+ # Clamp.
+ x = torch.tanh(x / 3) * 3
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
+ else:
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
+
+ else:
+ x = self.layers(x)
+
+ # scale image from [0, 1] to [-1, 1] to match diffusers convention
+ return x.mul(2).sub(1)
\ No newline at end of file
diff --git a/module/ip_adapter/attention_processor.py b/module/ip_adapter/attention_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed6cf755f7c2b36c81e6527d64bd0cc4e749d696
--- /dev/null
+++ b/module/ip_adapter/attention_processor.py
@@ -0,0 +1,1467 @@
+# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class AdaLayerNorm(nn.Module):
+ def __init__(self, embedding_dim: int, time_embedding_dim: int = None):
+ super().__init__()
+
+ if time_embedding_dim is None:
+ time_embedding_dim = embedding_dim
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
+ nn.init.zeros_(self.linear.weight)
+ nn.init.zeros_(self.linear.bias)
+
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+
+ def forward(
+ self, x: torch.Tensor, timestep_embedding: torch.Tensor
+ ):
+ emb = self.linear(self.silu(timestep_embedding))
+ shift, scale = emb.view(len(x), 1, -1).chunk(2, dim=-1)
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
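+# Editor's note: an illustrative sketch (dimensions are assumptions). AdaLayerNorm modulates a
+# non-affine LayerNorm with a shift/scale predicted from the timestep embedding; because the linear
+# layer is zero-initialized, it behaves like a plain LayerNorm at the start of training.
+#
+#   ada_ln = AdaLayerNorm(embedding_dim=640, time_embedding_dim=1280)
+#   y = ada_ln(x, temb)  # x: (B, seq_len, 640), temb: (B, 1280) -> y: (B, seq_len, 640)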
+
+class AttnProcessor(nn.Module):
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ ):
+ super().__init__()
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class IPAttnProcessor(nn.Module):
+ r"""
+ Attention processor for IP-Adapter.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ scale (`float`, defaults to 1.0):
+ The weight scale of the image prompt.
+ num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+ The context length of the image features.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+ hidden_states = hidden_states + self.scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
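+# Editor's note (illustrative, shapes are assumptions): IPAttnProcessor expects the text tokens and
+# the image-prompt tokens to be concatenated along the sequence dimension, image tokens last:
+#
+#   encoder_hidden_states = torch.cat([text_embeds, ip_tokens], dim=1)
+#   # text_embeds: e.g. (B, 77, cross_attention_dim), ip_tokens: (B, num_tokens, cross_attention_dim)
+#
+# In __call__ the last `num_tokens` entries are split off, projected through to_k_ip / to_v_ip, and
+# their attention output is added to the text-conditioned output with weight `scale`.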
+
+class TA_IPAttnProcessor(nn.Module):
+ r"""
+ Attention processor for IP-Adapter.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ scale (`float`, defaults to 1.0):
+ The weight scale of the image prompt.
+ num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+ The context length of the image features.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4):
+ super().__init__()
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+ self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
+ self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ # time-dependent adaLN
+ ip_key = self.ln_k_ip(ip_key, temb)
+ ip_value = self.ln_v_ip(ip_value, temb)
+
+ ip_key = attn.head_to_batch_dim(ip_key)
+ ip_value = attn.head_to_batch_dim(ip_value)
+
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+ hidden_states = hidden_states + self.scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class AttnProcessor2_0(torch.nn.Module):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ ):
+ super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ external_kv=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if external_kv:
+ key = torch.cat([key, external_kv.k], axis=1)
+ value = torch.cat([value, external_kv.v], axis=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class split_AttnProcessor2_0(torch.nn.Module):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ time_embedding_dim=None,
+ ):
+ super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ external_kv=None,
+ temb=None,
+ cat_dim=-2,
+ original_shape=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ # 2d to sequence.
+ height, width = hidden_states.shape[-2:]
+ if cat_dim==-2 or cat_dim==2:
+ hidden_states_0 = hidden_states[:, :, :height//2, :]
+ hidden_states_1 = hidden_states[:, :, -(height//2):, :]
+ elif cat_dim==-1 or cat_dim==3:
+ hidden_states_0 = hidden_states[:, :, :, :width//2]
+ hidden_states_1 = hidden_states[:, :, :, -(width//2):]
+ batch_size, channel, height, width = hidden_states_0.shape
+ hidden_states_0 = hidden_states_0.view(batch_size, channel, height * width).transpose(1, 2)
+ hidden_states_1 = hidden_states_1.view(batch_size, channel, height * width).transpose(1, 2)
+ else:
+ # directly split the sequence according to the concat dim.
+ single_dim = original_shape[2] if cat_dim==-2 or cat_dim==2 else original_shape[1]
+ hidden_states_0 = hidden_states[:, :single_dim*single_dim,:]
+ hidden_states_1 = hidden_states[:, single_dim*(single_dim+1):,:]
+
+ hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=1)
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ if external_kv:
+ key = torch.cat([key, external_kv.k], dim=1)
+ value = torch.cat([value, external_kv.v], dim=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ # spatially split.
+ hidden_states_0, hidden_states_1 = hidden_states.chunk(2, dim=1)
+
+ if input_ndim == 4:
+ hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if cat_dim==-2 or cat_dim==2:
+ hidden_states_pad = torch.zeros(batch_size, channel, 1, width)
+ elif cat_dim==-1 or cat_dim==3:
+ hidden_states_pad = torch.zeros(batch_size, channel, height, 1)
+ hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
+ hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim)
+ assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
+ else:
+ batch_size, sequence_length, inner_dim = hidden_states.shape
+ hidden_states_pad = torch.zeros(batch_size, single_dim, inner_dim)
+ hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
+ hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1)
+ assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
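+# Editor's note (layout inferred from the code above): split_AttnProcessor2_0 assumes the hidden
+# states hold two feature maps concatenated along `cat_dim` with a one-pixel separator in between,
+# e.g. for cat_dim=-2 and height 2*h + 1:
+#
+#   [ main map (h rows) | separator (1 row) | reference map (h rows) ]
+#
+# Both halves are flattened, joined along the sequence dimension and attended over jointly; the output
+# is split back and the separator is re-inserted as zeros so the shape matches the residual again.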
+
+class sep_split_AttnProcessor2_0(torch.nn.Module):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(
+ self,
+ hidden_size=None,
+ cross_attention_dim=None,
+ time_embedding_dim=None,
+ ):
+ super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ self.ln_k_ref = AdaLayerNorm(hidden_size, time_embedding_dim)
+ self.ln_v_ref = AdaLayerNorm(hidden_size, time_embedding_dim)
+ # self.hidden_size = hidden_size
+ # self.cross_attention_dim = cross_attention_dim
+ # self.scale = scale
+ # self.num_tokens = num_tokens
+
+ # self.to_q_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ # self.to_k_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ # self.to_v_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ external_kv=None,
+ temb=None,
+ cat_dim=-2,
+ original_shape=None,
+ ref_scale=1.0,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ # 2d to sequence.
+ height, width = hidden_states.shape[-2:]
+ if cat_dim==-2 or cat_dim==2:
+ hidden_states_0 = hidden_states[:, :, :height//2, :]
+ hidden_states_1 = hidden_states[:, :, -(height//2):, :]
+ elif cat_dim==-1 or cat_dim==3:
+ hidden_states_0 = hidden_states[:, :, :, :width//2]
+ hidden_states_1 = hidden_states[:, :, :, -(width//2):]
+ batch_size, channel, height, width = hidden_states_0.shape
+ hidden_states_0 = hidden_states_0.view(batch_size, channel, height * width).transpose(1, 2)
+ hidden_states_1 = hidden_states_1.view(batch_size, channel, height * width).transpose(1, 2)
+ else:
+ # directly split the sequence according to the concat dim.
+ single_dim = original_shape[2] if cat_dim==-2 or cat_dim==2 else original_shape[1]
+ hidden_states_0 = hidden_states[:, :single_dim*single_dim,:]
+ hidden_states_1 = hidden_states[:, single_dim*(single_dim+1):,:]
+
+ batch_size, sequence_length, _ = (
+ hidden_states_0.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states_0 = attn.group_norm(hidden_states_0.transpose(1, 2)).transpose(1, 2)
+ hidden_states_1 = attn.group_norm(hidden_states_1.transpose(1, 2)).transpose(1, 2)
+
+ query_0 = attn.to_q(hidden_states_0)
+ query_1 = attn.to_q(hidden_states_1)
+ key_0 = attn.to_k(hidden_states_0)
+ key_1 = attn.to_k(hidden_states_1)
+ value_0 = attn.to_v(hidden_states_0)
+ value_1 = attn.to_v(hidden_states_1)
+
+ # time-dependent adaLN
+ key_1 = self.ln_k_ref(key_1, temb)
+ value_1 = self.ln_v_ref(value_1, temb)
+
+ if external_kv:
+ key_1 = torch.cat([key_1, external_kv.k], dim=1)
+ value_1 = torch.cat([value_1, external_kv.v], dim=1)
+
+ inner_dim = key_0.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query_0 = query_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ query_1 = query_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_0 = key_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ key_1 = key_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_0 = value_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value_1 = value_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states_0 = F.scaled_dot_product_attention(
+ query_0, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_1 = F.scaled_dot_product_attention(
+ query_1, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ # cross-attn
+ _hidden_states_0 = F.scaled_dot_product_attention(
+ query_0, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_0 = hidden_states_0 + ref_scale * _hidden_states_0 * 10
+
+ # TODO: drop this cross-attn
+ _hidden_states_1 = F.scaled_dot_product_attention(
+ query_1, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_1 = hidden_states_1 + ref_scale * _hidden_states_1
+
+ hidden_states_0 = hidden_states_0.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_1 = hidden_states_1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states_0 = hidden_states_0.to(query_0.dtype)
+ hidden_states_1 = hidden_states_1.to(query_1.dtype)
+
+
+ # linear proj
+ hidden_states_0 = attn.to_out[0](hidden_states_0)
+ hidden_states_1 = attn.to_out[0](hidden_states_1)
+ # dropout
+ hidden_states_0 = attn.to_out[1](hidden_states_0)
+ hidden_states_1 = attn.to_out[1](hidden_states_1)
+
+
+ if input_ndim == 4:
+ hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if cat_dim==-2 or cat_dim==2:
+ hidden_states_pad = torch.zeros(batch_size, channel, 1, width)
+ elif cat_dim==-1 or cat_dim==3:
+ hidden_states_pad = torch.zeros(batch_size, channel, height, 1)
+ hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
+ hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim)
+ assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
+ else:
+ batch_size, sequence_length, inner_dim = hidden_states.shape
+ hidden_states_pad = torch.zeros(batch_size, single_dim, inner_dim)
+ hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
+ hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1)
+ assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class AdditiveKV_AttnProcessor2_0(torch.nn.Module):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(
+ self,
+ hidden_size: int = None,
+ cross_attention_dim: int = None,
+ time_embedding_dim: int = None,
+ additive_scale: float = 1.0,
+ ):
+ super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ self.additive_scale = additive_scale
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ external_kv=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ if external_kv:
+ key = external_kv.k
+ value = external_kv.v
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ external_attn_output = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ external_attn_output = external_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states + self.additive_scale * external_attn_output
+
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
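+# Editor's note (assumption inferred from usage): `external_kv` is expected to be a small container
+# exposing precomputed `.k` and `.v` tensors of shape (B, ext_seq_len, hidden_size), e.g. an instance
+# of collections.namedtuple("ExternalKV", ["k", "v"]); the attention output computed against it is
+# added to the regular attention output with weight `additive_scale`.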
+
+class TA_AdditiveKV_AttnProcessor2_0(torch.nn.Module):
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(
+ self,
+ hidden_size: int = None,
+ cross_attention_dim: int = None,
+ time_embedding_dim: int = None,
+ additive_scale: float = 1.0,
+ ):
+ super().__init__()
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ self.ln_k = AdaLayerNorm(hidden_size, time_embedding_dim)
+ self.ln_v = AdaLayerNorm(hidden_size, time_embedding_dim)
+ self.additive_scale = additive_scale
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ external_kv=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+
+ if external_kv:
+ key = external_kv.k
+ value = external_kv.v
+
+ # time-dependent adaLN
+ key = self.ln_k(key, temb)
+ value = self.ln_v(value, temb)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ external_attn_output = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ external_attn_output = external_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states + self.additive_scale * external_attn_output
+
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class IPAttnProcessor2_0(torch.nn.Module):
+ r"""
+ Attention processor for IP-Adapter for PyTorch 2.0.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ scale (`float`, defaults to 1.0):
+ The weight scale of the image prompt.
+ num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+ The context length of the image features.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ if isinstance(encoder_hidden_states, tuple):
+ # FIXME: now hard coded to single image prompt.
+ batch_size, _, hid_dim = encoder_hidden_states[0].shape
+ ip_tokens = encoder_hidden_states[1][0]
+ encoder_hidden_states = torch.cat([encoder_hidden_states[0], ip_tokens], dim=1)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+ hidden_states = hidden_states + self.scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class TA_IPAttnProcessor2_0(torch.nn.Module):
+ r"""
+ Attention processor for IP-Adapter for PyTorch 2.0.
+ Args:
+ hidden_size (`int`):
+ The hidden size of the attention layer.
+ cross_attention_dim (`int`):
+ The number of channels in the `encoder_hidden_states`.
+ scale (`float`, defaults to 1.0):
+ The weight scale of the image prompt.
+ num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
+ The context length of the image features.
+ """
+
+ def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+ self.scale = scale
+ self.num_tokens = num_tokens
+
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+ self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
+ self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ external_kv=None,
+ temb=None,
+ ):
+ assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
+
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ if not isinstance(encoder_hidden_states, tuple):
+ # get encoder_hidden_states, ip_hidden_states
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states, ip_hidden_states = (
+ encoder_hidden_states[:, :end_pos, :],
+ encoder_hidden_states[:, end_pos:, :],
+ )
+ else:
+ # FIXME: now hard coded to single image prompt.
+ batch_size, _, hid_dim = encoder_hidden_states[0].shape
+ ip_hidden_states = encoder_hidden_states[1][0]
+ encoder_hidden_states = encoder_hidden_states[0]
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if external_kv:
+ key = torch.cat([key, external_kv.k], axis=1)
+ value = torch.cat([value, external_kv.v], axis=1)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # for ip-adapter
+ ip_key = self.to_k_ip(ip_hidden_states)
+ ip_value = self.to_v_ip(ip_hidden_states)
+
+ # time-dependent adaLN
+ ip_key = self.ln_k_ip(ip_key, temb)
+ ip_value = self.ln_v_ip(ip_value, temb)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ ip_hidden_states = F.scaled_dot_product_attention(
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+ hidden_states = hidden_states + self.scale * ip_hidden_states
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+## for controlnet
+class CNAttnProcessor:
+ r"""
+ Default processor for performing attention-related computations.
+ """
+
+ def __init__(self, num_tokens=4):
+ self.num_tokens = num_tokens
+
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.head_to_batch_dim(query)
+ key = attn.head_to_batch_dim(key)
+ value = attn.head_to_batch_dim(value)
+
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
+ hidden_states = torch.bmm(attention_probs, value)
+ hidden_states = attn.batch_to_head_dim(hidden_states)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class CNAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __init__(self, num_tokens=4):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+ self.num_tokens = num_tokens
+
+ def __call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=None,
+ temb=None,
+ ):
+ residual = hidden_states
+
+ if attn.spatial_norm is not None:
+ hidden_states = attn.spatial_norm(hidden_states, temb)
+
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ else:
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
+ if attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ # TODO: add support for attn.scale when we move to Torch 2.1
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+def init_attn_proc(unet, ip_adapter_tokens=16, use_lcm=False, use_adaln=True, use_external_kv=False):
+ attn_procs = {}
+ unet_sd = unet.state_dict()
+ for name in unet.attn_processors.keys():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is None:
+ if use_external_kv:
+ attn_procs[name] = AdditiveKV_AttnProcessor2_0(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ time_embedding_dim=1280,
+ ) if hasattr(F, "scaled_dot_product_attention") else AdditiveKV_AttnProcessor()
+ else:
+ attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+ else:
+ if use_adaln:
+ layer_name = name.split(".processor")[0]
+ if use_lcm:
+ weights = {
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.base_layer.weight"],
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.base_layer.weight"],
+ }
+ else:
+ weights = {
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
+ }
+ attn_procs[name] = TA_IPAttnProcessor2_0(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ num_tokens=ip_adapter_tokens,
+ time_embedding_dim=1280,
+ ) if hasattr(F, "scaled_dot_product_attention") else \
+ TA_IPAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ num_tokens=ip_adapter_tokens,
+ time_embedding_dim=1280,
+ )
+ attn_procs[name].load_state_dict(weights, strict=False)
+ else:
+ attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+
+ return attn_procs
+
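+# Editor's note: a minimal usage sketch (assumed wiring, not part of the original module). The dict
+# returned above maps attention-processor names to modules and can be installed on a diffusers UNet:
+#
+#   attn_procs = init_attn_proc(unet, ip_adapter_tokens=16, use_adaln=True)
+#   unet.set_attn_processor(attn_procs)
+#   adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
+#
+# Note that the non-SDPA fallback `AdditiveKV_AttnProcessor` referenced above is not defined in this
+# module, so `use_external_kv=True` effectively assumes PyTorch 2.0.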
+
+def init_aggregator_attn_proc(unet, use_adaln=False, split_attn=False):
+ attn_procs = {}
+ unet_sd = unet.state_dict()
+ for name in unet.attn_processors.keys():
+ # get layer name and hidden dim
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ # init attn proc
+ if split_attn:
+ # layer_name = name.split(".processor")[0]
+ # weights = {
+ # "to_q_ref.weight": unet_sd[layer_name + ".to_q.weight"],
+ # "to_k_ref.weight": unet_sd[layer_name + ".to_k.weight"],
+ # "to_v_ref.weight": unet_sd[layer_name + ".to_v.weight"],
+ # }
+ attn_procs[name] = (
+ sep_split_AttnProcessor2_0(
+ hidden_size=hidden_size,
+ cross_attention_dim=hidden_size,
+ time_embedding_dim=1280,
+ )
+ if use_adaln
+ else split_AttnProcessor2_0(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ time_embedding_dim=1280,
+ )
+ )
+ # attn_procs[name].load_state_dict(weights, strict=False)
+ else:
+ attn_procs[name] = (
+ AttnProcessor2_0(
+ hidden_size=hidden_size,
+ cross_attention_dim=hidden_size,
+ )
+ if hasattr(F, "scaled_dot_product_attention")
+ else AttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=hidden_size,
+ )
+ )
+
+ return attn_procs
diff --git a/module/ip_adapter/ip_adapter.py b/module/ip_adapter/ip_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..10f01d4f38f95fc930af2db66bef72757cc54455
--- /dev/null
+++ b/module/ip_adapter/ip_adapter.py
@@ -0,0 +1,236 @@
+import os
+import torch
+from typing import List
+from collections import namedtuple, OrderedDict
+
+def is_torch2_available():
+ return hasattr(torch.nn.functional, "scaled_dot_product_attention")
+
+if is_torch2_available():
+ from .attention_processor import (
+ AttnProcessor2_0 as AttnProcessor,
+ )
+ from .attention_processor import (
+ CNAttnProcessor2_0 as CNAttnProcessor,
+ )
+ from .attention_processor import (
+ IPAttnProcessor2_0 as IPAttnProcessor,
+ )
+ from .attention_processor import (
+ TA_IPAttnProcessor2_0 as TA_IPAttnProcessor,
+ )
+else:
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor, TA_IPAttnProcessor
+
+
+class ImageProjModel(torch.nn.Module):
+ """Projection Model"""
+
+ def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
+ super().__init__()
+
+ self.cross_attention_dim = cross_attention_dim
+ self.clip_extra_context_tokens = clip_extra_context_tokens
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds):
+ embeds = image_embeds
+ clip_extra_context_tokens = self.proj(embeds).reshape(
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
+ )
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+ return clip_extra_context_tokens
+
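+# Editor's note (illustrative shapes, assumed defaults): ImageProjModel turns one pooled CLIP image
+# embedding into `clip_extra_context_tokens` pseudo text tokens.
+#
+#   proj = ImageProjModel(cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4)
+#   ip_tokens = proj(image_embeds)  # image_embeds: (B, 1280) -> ip_tokens: (B, 4, 2048)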
+
+class MLPProjModel(torch.nn.Module):
+ """SD model with image prompt"""
+ def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280):
+ super().__init__()
+
+ self.proj = torch.nn.Sequential(
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
+ torch.nn.GELU(),
+ torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
+ torch.nn.LayerNorm(cross_attention_dim)
+ )
+
+ def forward(self, image_embeds):
+ clip_extra_context_tokens = self.proj(image_embeds)
+ return clip_extra_context_tokens
+
+
+class MultiIPAdapterImageProjection(torch.nn.Module):
+ def __init__(self, IPAdapterImageProjectionLayers):
+ super().__init__()
+ self.image_projection_layers = torch.nn.ModuleList(IPAdapterImageProjectionLayers)
+
+ def forward(self, image_embeds: List[torch.FloatTensor]):
+ projected_image_embeds = []
+
+ # currently, we accept `image_embeds` as
+ # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
+ # 2. a list of `n` tensors where `n` is the number of ip-adapters; each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
+ if not isinstance(image_embeds, list):
+ image_embeds = [image_embeds.unsqueeze(1)]
+
+ if len(image_embeds) != len(self.image_projection_layers):
+ raise ValueError(
+ f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
+ )
+
+ for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
+ batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
+ image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
+ image_embed = image_projection_layer(image_embed)
+ # image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
+
+ projected_image_embeds.append(image_embed)
+
+ return projected_image_embeds
+
+
+class IPAdapter(torch.nn.Module):
+ """IP-Adapter"""
+ def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
+ super().__init__()
+ self.unet = unet
+ self.image_proj = image_proj_model
+ self.ip_adapter = adapter_modules
+
+ if ckpt_path is not None:
+ self.load_from_checkpoint(ckpt_path)
+
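+    # Training-time forward: project the image embeddings into IP tokens and append
+    # them to the text encoder hidden states before predicting the noise residual.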
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
+ ip_tokens = self.image_proj(image_embeds)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
+ return noise_pred
+
+ def load_from_checkpoint(self, ckpt_path: str):
+ # Calculate original checksums
+ orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
+ orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
+
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+ keys = list(state_dict.keys())
+ if keys != ["image_proj", "ip_adapter"]:
+            # Deferred import avoids a circular import (module.ip_adapter.utils imports from this module).
+            from .utils import revise_state_dict
+            state_dict = revise_state_dict(state_dict)
+
+ # Load state dict for image_proj_model and adapter_modules
+ self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
+ self.ip_adapter.load_state_dict(state_dict["ip_adapter"], strict=True)
+
+ # Calculate new checksums
+ new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
+ new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
+
+ # Verify if the weights have changed
+ assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
+ assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
+
+
+class IPAdapterPlus(torch.nn.Module):
+ """IP-Adapter"""
+ def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
+ super().__init__()
+ self.unet = unet
+ self.image_proj = image_proj_model
+ self.ip_adapter = adapter_modules
+
+ if ckpt_path is not None:
+ self.load_from_checkpoint(ckpt_path)
+
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
+ ip_tokens = self.image_proj(image_embeds)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
+ return noise_pred
+
+ def load_from_checkpoint(self, ckpt_path: str):
+ # Calculate original checksums
+ orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
+ orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
+ org_unet_sum = []
+ for attn_name, attn_proc in self.unet.attn_processors.items():
+ if isinstance(attn_proc, (TA_IPAttnProcessor, IPAttnProcessor)):
+ org_unet_sum.append(torch.sum(torch.stack([torch.sum(p) for p in attn_proc.parameters()])))
+ org_unet_sum = torch.sum(torch.stack(org_unet_sum))
+
+ state_dict = torch.load(ckpt_path, map_location="cpu")
+ keys = list(state_dict.keys())
+ if keys != ["image_proj", "ip_adapter"]:
+            # Deferred import avoids a circular import (module.ip_adapter.utils imports from this module).
+            from .utils import revise_state_dict
+            state_dict = revise_state_dict(state_dict)
+
+ # Check if 'latents' exists in both the saved state_dict and the current model's state_dict
+ strict_load_image_proj_model = True
+ if "latents" in state_dict["image_proj"] and "latents" in self.image_proj.state_dict():
+ # Check if the shapes are mismatched
+ if state_dict["image_proj"]["latents"].shape != self.image_proj.state_dict()["latents"].shape:
+ print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
+ print("Removing 'latents' from checkpoint and loading the rest of the weights.")
+ del state_dict["image_proj"]["latents"]
+ strict_load_image_proj_model = False
+
+ # Load state dict for image_proj_model and adapter_modules
+ self.image_proj.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
+ missing_key, unexpected_key = self.ip_adapter.load_state_dict(state_dict["ip_adapter"], strict=False)
+ if len(missing_key) > 0:
+ for ms in missing_key:
+ if "ln" not in ms:
+ raise ValueError(f"Missing key in adapter_modules: {len(missing_key)}")
+ if len(unexpected_key) > 0:
+ raise ValueError(f"Unexpected key in adapter_modules: {len(unexpected_key)}")
+
+ # Calculate new checksums
+ new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
+ new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
+
+ # Verify if the weights loaded to unet
+ unet_sum = []
+ for attn_name, attn_proc in self.unet.attn_processors.items():
+ if isinstance(attn_proc, (TA_IPAttnProcessor, IPAttnProcessor)):
+ unet_sum.append(torch.sum(torch.stack([torch.sum(p) for p in attn_proc.parameters()])))
+ unet_sum = torch.sum(torch.stack(unet_sum))
+
+ assert org_unet_sum != unet_sum, "Weights of adapter_modules in unet did not change!"
+        assert torch.abs(unet_sum - new_adapter_sum) < 1e-4, "Weights of adapter_modules did not load to unet!"
+
+ # Verify if the weights have changed
+ assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
+        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
+
+
+class IPAdapterXL(IPAdapter):
+ """SDXL"""
+
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
+ ip_tokens = self.image_proj(image_embeds)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
+ return noise_pred
+
+
+class IPAdapterPlusXL(IPAdapterPlus):
+ """IP-Adapter with fine-grained features"""
+
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
+ ip_tokens = self.image_proj(image_embeds)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
+ # Predict the noise residual
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
+ return noise_pred
+
+
+class IPAdapterFull(IPAdapterPlus):
+ """IP-Adapter with full features"""
+
+ def init_proj(self):
+ image_proj_model = MLPProjModel(
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
+ ).to(self.device, dtype=torch.float16)
+ return image_proj_model
diff --git a/module/ip_adapter/resampler.py b/module/ip_adapter/resampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..72295f90b1e9fdfe717de9af317fc55e5d466a09
--- /dev/null
+++ b/module/ip_adapter/resampler.py
@@ -0,0 +1,158 @@
+# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
+
+import math
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from einops.layers.torch import Rearrange
+
+
+# FFN
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+    # keeps (bs, n_heads, length, dim_per_head); the reshape only makes the layout contiguous
+    x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+            latents (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+class Resampler(nn.Module):
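+    # Perceiver-style resampler: a fixed set of `num_queries` learned latent tokens
+    # repeatedly attends to the projected image feature sequence, producing a compact
+    # set of `output_dim`-wide tokens for the IP cross-attention layers.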
+ def __init__(
+ self,
+ dim=1280,
+ depth=4,
+ dim_head=64,
+ heads=20,
+ num_queries=64,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+ max_seq_len: int = 257, # CLIP tokens + CLS token
+ apply_pos_emb: bool = False,
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
+ ):
+ super().__init__()
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
+
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+
+ self.proj_in = nn.Linear(embedding_dim, dim)
+
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.to_latents_from_mean_pooled_seq = (
+ nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, dim * num_latents_mean_pooled),
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
+ )
+ if num_latents_mean_pooled > 0
+ else None
+ )
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, x):
+ if self.pos_emb is not None:
+ n, device = x.shape[1], x.device
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
+ x = x + pos_emb
+
+ latents = self.latents.repeat(x.size(0), 1, 1)
+
+ x = self.proj_in(x)
+
+ if self.to_latents_from_mean_pooled_seq:
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ return self.norm_out(latents)
+
+
+def masked_mean(t, *, dim, mask=None):
+ if mask is None:
+ return t.mean(dim=dim)
+
+ denom = mask.sum(dim=dim, keepdim=True)
+ mask = rearrange(mask, "b n -> b n 1")
+ masked_t = t.masked_fill(~mask, 0.0)
+
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
diff --git a/module/ip_adapter/utils.py b/module/ip_adapter/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..64c45cd85ac8aa2a03ad5fa66429ddd46824cd69
--- /dev/null
+++ b/module/ip_adapter/utils.py
@@ -0,0 +1,248 @@
+import torch
+from collections import namedtuple, OrderedDict
+from safetensors import safe_open
+from .attention_processor import init_attn_proc
+from .ip_adapter import MultiIPAdapterImageProjection
+from .resampler import Resampler
+from transformers import (
+ AutoModel, AutoImageProcessor,
+ CLIPVisionModelWithProjection, CLIPImageProcessor)
+
+
+def init_adapter_in_unet(
+ unet,
+ image_proj_model=None,
+ pretrained_model_path_or_dict=None,
+ adapter_tokens=64,
+ embedding_dim=None,
+ use_lcm=False,
+ use_adaln=True,
+ ):
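+    # Installs IP-Adapter cross-attention processors and an image projection model into
+    # `unet`, optionally restoring their weights from `pretrained_model_path_or_dict`.
+    # Minimal usage sketch (hypothetical path and embedding_dim):
+    #   init_adapter_in_unet(unet, embedding_dim=1024,
+    #                        pretrained_model_path_or_dict="path/to/ip_adapter.bin")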
+ device = unet.device
+ dtype = unet.dtype
+ if image_proj_model is None:
+ assert embedding_dim is not None, "embedding_dim must be provided if image_proj_model is None."
+ image_proj_model = Resampler(
+ embedding_dim=embedding_dim,
+ output_dim=unet.config.cross_attention_dim,
+ num_queries=adapter_tokens,
+ )
+ if pretrained_model_path_or_dict is not None:
+ if not isinstance(pretrained_model_path_or_dict, dict):
+ if pretrained_model_path_or_dict.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(pretrained_model_path_or_dict, framework="pt", device=unet.device) as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(pretrained_model_path_or_dict, map_location=unet.device)
+ else:
+ state_dict = pretrained_model_path_or_dict
+ keys = list(state_dict.keys())
+ if "image_proj" not in keys and "ip_adapter" not in keys:
+ state_dict = revise_state_dict(state_dict)
+
+    # Create IP cross-attention processors in the UNet.
+ attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
+ unet.set_attn_processor(attn_procs)
+
+    # Load pretrained weights if provided.
+ if pretrained_model_path_or_dict is not None:
+ if "ip_adapter" in state_dict.keys():
+ adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
+ missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
+ for mk in missing:
+ if "ln" not in mk:
+ raise ValueError(f"Missing keys in adapter_modules: {missing}")
+ if "image_proj" in state_dict.keys():
+ image_proj_model.load_state_dict(state_dict["image_proj"])
+
+ # Load image projectors into iterable ModuleList.
+ image_projection_layers = []
+ image_projection_layers.append(image_proj_model)
+ unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+
+    # Adjust the UNet config to handle the additional IP hidden states.
+ unet.config.encoder_hid_dim_type = "ip_image_proj"
+ unet.to(dtype=dtype, device=device)
+
+
+def load_adapter_to_pipe(
+ pipe,
+ pretrained_model_path_or_dict,
+ image_encoder_or_path=None,
+ feature_extractor_or_path=None,
+ use_clip_encoder=False,
+ adapter_tokens=64,
+ use_lcm=False,
+ use_adaln=True,
+ ):
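+    # Pipeline-level counterpart of `init_adapter_in_unet`: besides installing the
+    # adapter into the UNet, it registers an image encoder and feature extractor on the
+    # pipeline if they are not present yet. Hypothetical usage sketch:
+    #   load_adapter_to_pipe(pipe, "path/to/ip_adapter.bin",
+    #                        image_encoder_or_path="facebook/dinov2-large")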
+
+ if not isinstance(pretrained_model_path_or_dict, dict):
+ if pretrained_model_path_or_dict.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
+ else:
+ state_dict = pretrained_model_path_or_dict
+ keys = list(state_dict.keys())
+ if "image_proj" not in keys and "ip_adapter" not in keys:
+ state_dict = revise_state_dict(state_dict)
+
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
+ if image_encoder_or_path is not None:
+ if isinstance(image_encoder_or_path, str):
+ feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path
+
+ image_encoder_or_path = (
+ CLIPVisionModelWithProjection.from_pretrained(
+ image_encoder_or_path
+ ) if use_clip_encoder else
+ AutoModel.from_pretrained(image_encoder_or_path)
+ )
+
+ if feature_extractor_or_path is not None:
+ if isinstance(feature_extractor_or_path, str):
+ feature_extractor_or_path = (
+ CLIPImageProcessor() if use_clip_encoder else
+ AutoImageProcessor.from_pretrained(feature_extractor_or_path)
+ )
+
+ # create image encoder if it has not been registered to the pipeline yet
+ if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
+ image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
+ pipe.register_modules(image_encoder=image_encoder)
+ else:
+ image_encoder = pipe.image_encoder
+
+ # create feature extractor if it has not been registered to the pipeline yet
+ if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
+ feature_extractor = feature_extractor_or_path
+ pipe.register_modules(feature_extractor=feature_extractor)
+ else:
+ feature_extractor = pipe.feature_extractor
+
+ # load adapter into unet
+ unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
+ attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
+ unet.set_attn_processor(attn_procs)
+ image_proj_model = Resampler(
+ embedding_dim=image_encoder.config.hidden_size,
+ output_dim=unet.config.cross_attention_dim,
+ num_queries=adapter_tokens,
+ )
+
+    # Load pretrained weights if provided.
+ if "ip_adapter" in state_dict.keys():
+ adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
+ missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
+ for mk in missing:
+ if "ln" not in mk:
+ raise ValueError(f"Missing keys in adapter_modules: {missing}")
+ if "image_proj" in state_dict.keys():
+ image_proj_model.load_state_dict(state_dict["image_proj"])
+
+ # convert IP-Adapter Image Projection layers to diffusers
+ image_projection_layers = []
+ image_projection_layers.append(image_proj_model)
+ unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+
+    # Adjust the UNet config to handle the additional IP hidden states.
+ unet.config.encoder_hid_dim_type = "ip_image_proj"
+ unet.to(dtype=pipe.dtype, device=pipe.device)
+
+
+def revise_state_dict(old_state_dict_or_path, map_location="cpu"):
+ new_state_dict = OrderedDict()
+ new_state_dict["image_proj"] = OrderedDict()
+ new_state_dict["ip_adapter"] = OrderedDict()
+ if isinstance(old_state_dict_or_path, str):
+ old_state_dict = torch.load(old_state_dict_or_path, map_location=map_location)
+ else:
+ old_state_dict = old_state_dict_or_path
+ for name, weight in old_state_dict.items():
+ if name.startswith("image_proj_model."):
+ new_state_dict["image_proj"][name[len("image_proj_model."):]] = weight
+ elif name.startswith("adapter_modules."):
+ new_state_dict["ip_adapter"][name[len("adapter_modules."):]] = weight
+ return new_state_dict
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+def encode_image(image_encoder, feature_extractor, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_enc_hidden_states
+ else:
+ if isinstance(image_encoder, CLIPVisionModelWithProjection):
+ # CLIP image encoder.
+ image_embeds = image_encoder(image).image_embeds
+ else:
+ # DINO image encoder.
+ image_embeds = image_encoder(image).last_hidden_state
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+
+def prepare_training_image_embeds(
+    image_encoder, feature_extractor,
+    ip_adapter_image, ip_adapter_image_embeds,
+    device, drop_rate, output_hidden_state, idx_to_replace=None,
+    num_images_per_prompt=1, do_classifier_free_guidance=False,
+):
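+    # During training, each image in the batch is independently replaced by zeros with
+    # probability `drop_rate` (image-condition dropout, commonly used to enable
+    # classifier-free guidance) before being encoded. `num_images_per_prompt` and
+    # `do_classifier_free_guidance` only apply when precomputed
+    # `ip_adapter_image_embeds` are passed.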
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ # if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):
+ # raise ValueError(
+ # f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ # )
+
+ image_embeds = []
+ for single_ip_adapter_image in ip_adapter_image:
+ if idx_to_replace is None:
+ idx_to_replace = torch.rand(len(single_ip_adapter_image)) < drop_rate
+ zero_ip_adapter_image = torch.zeros_like(single_ip_adapter_image)
+ single_ip_adapter_image[idx_to_replace] = zero_ip_adapter_image[idx_to_replace]
+ single_image_embeds = encode_image(
+ image_encoder, feature_extractor, single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds], dim=1) # FIXME
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
\ No newline at end of file
diff --git a/module/min_sdxl.py b/module/min_sdxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd94b03dc8f7a1484888085b440a66f99a4b0275
--- /dev/null
+++ b/module/min_sdxl.py
@@ -0,0 +1,915 @@
+# Modified from minSDXL by Simo Ryu:
+# https://github.com/cloneofsimo/minSDXL ,
+# which is in turn modified from the original code of:
+# https://github.com/huggingface/diffusers
+# and therefore carries the Apache 2.0 license
+
+from typing import Optional, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import inspect
+
+from collections import namedtuple
+
+from torch.fft import fftn, fftshift, ifftn, ifftshift
+
+from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
+
+# Implementation of FreeU for minsdxl
+
+def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
+ """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
+
+ This version of the method comes from here:
+ https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
+ """
+ x = x_in
+ B, C, H, W = x.shape
+
+ # Non-power of 2 images must be float32
+ if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
+ x = x.to(dtype=torch.float32)
+
+ # FFT
+ x_freq = fftn(x, dim=(-2, -1))
+ x_freq = fftshift(x_freq, dim=(-2, -1))
+
+ B, C, H, W = x_freq.shape
+ mask = torch.ones((B, C, H, W), device=x.device)
+
+ crow, ccol = H // 2, W // 2
+ mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
+ x_freq = x_freq * mask
+
+ # IFFT
+ x_freq = ifftshift(x_freq, dim=(-2, -1))
+ x_filtered = ifftn(x_freq, dim=(-2, -1)).real
+
+ return x_filtered.to(dtype=x_in.dtype)
+
+
+def apply_freeu(
+ resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs):
+ """Applies the FreeU mechanism as introduced in https:
+ //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
+
+ Args:
+ resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
+ hidden_states (`torch.Tensor`): Inputs to the underlying block.
+ res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
+ s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
+ s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ if resolution_idx == 0:
+ num_half_channels = hidden_states.shape[1] // 2
+ hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"]
+ res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"])
+ if resolution_idx == 1:
+ num_half_channels = hidden_states.shape[1] // 2
+ hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"]
+ res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
+
+ return hidden_states, res_hidden_states
+
+# Diffusers-style LoRA to keep everything in the min_sdxl.py file
+
+class LoRALinearLayer(nn.Module):
+ r"""
+ A linear layer that is used with LoRA.
+
+ Parameters:
+ in_features (`int`):
+ Number of input features.
+ out_features (`int`):
+ Number of output features.
+ rank (`int`, `optional`, defaults to 4):
+ The rank of the LoRA layer.
+ network_alpha (`float`, `optional`, defaults to `None`):
+ The value of the network alpha used for stable learning and preventing underflow. This value has the same
+ meaning as the `--network_alpha` option in the kohya-ss trainer script. See
+ https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+ device (`torch.device`, `optional`, defaults to `None`):
+ The device to use for the layer's weights.
+ dtype (`torch.dtype`, `optional`, defaults to `None`):
+ The dtype to use for the layer's weights.
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ rank: int = 4,
+ network_alpha: Optional[float] = None,
+ device: Optional[Union[torch.device, str]] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ super().__init__()
+
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+ self.network_alpha = network_alpha
+ self.rank = rank
+ self.out_features = out_features
+ self.in_features = in_features
+
+ nn.init.normal_(self.down.weight, std=1 / rank)
+ nn.init.zeros_(self.up.weight)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ orig_dtype = hidden_states.dtype
+ dtype = self.down.weight.dtype
+
+ down_hidden_states = self.down(hidden_states.to(dtype))
+ up_hidden_states = self.up(down_hidden_states)
+
+ if self.network_alpha is not None:
+ up_hidden_states *= self.network_alpha / self.rank
+
+ return up_hidden_states.to(orig_dtype)
+
+class LoRACompatibleLinear(nn.Linear):
+ """
+ A Linear layer that can be used with LoRA.
+ """
+
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.lora_layer = lora_layer
+
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
+ self.lora_layer = lora_layer
+
+ def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
+ if self.lora_layer is None:
+ return
+
+ dtype, device = self.weight.data.dtype, self.weight.data.device
+
+ w_orig = self.weight.data.float()
+ w_up = self.lora_layer.up.weight.data.float()
+ w_down = self.lora_layer.down.weight.data.float()
+
+ if self.lora_layer.network_alpha is not None:
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
+
+ fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+
+ if safe_fusing and torch.isnan(fused_weight).any().item():
+ raise ValueError(
+ "This LoRA weight seems to be broken. "
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
+ "LoRA weights will not be fused."
+ )
+
+ self.weight.data = fused_weight.to(device=device, dtype=dtype)
+
+ # we can drop the lora layer now
+ self.lora_layer = None
+
+ # offload the up and down matrices to CPU to not blow the memory
+ self.w_up = w_up.cpu()
+ self.w_down = w_down.cpu()
+ self._lora_scale = lora_scale
+
+ def _unfuse_lora(self):
+ if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
+ return
+
+ fused_weight = self.weight.data
+ dtype, device = fused_weight.dtype, fused_weight.device
+
+ w_up = self.w_up.to(device=device).float()
+ w_down = self.w_down.to(device).float()
+
+ unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
+ self.weight.data = unfused_weight.to(device=device, dtype=dtype)
+
+ self.w_up = None
+ self.w_down = None
+
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+ if self.lora_layer is None:
+ out = super().forward(hidden_states)
+ return out
+ else:
+ out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
+ return out
+
+class Timesteps(nn.Module):
+ def __init__(self, num_channels: int = 320):
+ super().__init__()
+ self.num_channels = num_channels
+
+ def forward(self, timesteps):
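+        # Standard sinusoidal timestep embedding: log-spaced frequencies (max period
+        # 10000), with the cos half concatenated before the sin half.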
+ half_dim = self.num_channels // 2
+ exponent = -math.log(10000) * torch.arange(
+ half_dim, dtype=torch.float32, device=timesteps.device
+ )
+        exponent = exponent / half_dim  # (half_dim - downscale_freq_shift) with downscale_freq_shift = 0
+
+ emb = torch.exp(exponent)
+ emb = timesteps[:, None].float() * emb[None, :]
+
+ sin_emb = torch.sin(emb)
+ cos_emb = torch.cos(emb)
+ emb = torch.cat([cos_emb, sin_emb], dim=-1)
+
+ return emb
+
+
+class TimestepEmbedding(nn.Module):
+ def __init__(self, in_features, out_features):
+ super(TimestepEmbedding, self).__init__()
+ self.linear_1 = nn.Linear(in_features, out_features, bias=True)
+ self.act = nn.SiLU()
+ self.linear_2 = nn.Linear(out_features, out_features, bias=True)
+
+ def forward(self, sample):
+ sample = self.linear_1(sample)
+ sample = self.act(sample)
+ sample = self.linear_2(sample)
+
+ return sample
+
+
+class ResnetBlock2D(nn.Module):
+ def __init__(self, in_channels, out_channels, conv_shortcut=True):
+ super(ResnetBlock2D, self).__init__()
+ self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-05, affine=True)
+ self.conv1 = nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ self.time_emb_proj = nn.Linear(1280, out_channels, bias=True)
+ self.norm2 = nn.GroupNorm(32, out_channels, eps=1e-05, affine=True)
+ self.dropout = nn.Dropout(p=0.0, inplace=False)
+ self.conv2 = nn.Conv2d(
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+ self.nonlinearity = nn.SiLU()
+ self.conv_shortcut = None
+ if conv_shortcut:
+ self.conv_shortcut = nn.Conv2d(
+ in_channels, out_channels, kernel_size=1, stride=1
+ )
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ temb = self.nonlinearity(temb)
+ temb = self.time_emb_proj(temb)[:, :, None, None]
+ hidden_states = hidden_states + temb
+ hidden_states = self.norm2(hidden_states)
+
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = input_tensor + hidden_states
+
+ return output_tensor
+
+
+class Attention(nn.Module):
+ def __init__(
+ self, inner_dim, cross_attention_dim=None, num_heads=None, dropout=0.0, processor=None, scale_qk=True
+ ):
+ super(Attention, self).__init__()
+ if num_heads is None:
+ self.head_dim = 64
+ self.num_heads = inner_dim // self.head_dim
+ else:
+ self.num_heads = num_heads
+ self.head_dim = inner_dim // num_heads
+
+ self.scale = self.head_dim**-0.5
+ if cross_attention_dim is None:
+ cross_attention_dim = inner_dim
+ self.to_q = LoRACompatibleLinear(inner_dim, inner_dim, bias=False)
+ self.to_k = LoRACompatibleLinear(cross_attention_dim, inner_dim, bias=False)
+ self.to_v = LoRACompatibleLinear(cross_attention_dim, inner_dim, bias=False)
+
+ self.to_out = nn.ModuleList(
+ [LoRACompatibleLinear(inner_dim, inner_dim), nn.Dropout(dropout, inplace=False)]
+ )
+
+ self.scale_qk = scale_qk
+ if processor is None:
+ processor = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
+ )
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ r"""
+ The forward method of the `Attention` class.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ The hidden states of the query.
+ encoder_hidden_states (`torch.Tensor`, *optional*):
+ The hidden states of the encoder.
+ attention_mask (`torch.Tensor`, *optional*):
+ The attention mask to use. If `None`, no mask is applied.
+ **cross_attention_kwargs:
+ Additional keyword arguments to pass along to the cross attention.
+
+ Returns:
+ `torch.Tensor`: The output of the attention layer.
+ """
+ # The `Attention` class can call different attention processors / attention functions
+ # here we simply pass along all tensors to the selected processor class
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
+
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ print(
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+ def orig_forward(self, hidden_states, encoder_hidden_states=None):
+ q = self.to_q(hidden_states)
+ k = (
+ self.to_k(encoder_hidden_states)
+ if encoder_hidden_states is not None
+ else self.to_k(hidden_states)
+ )
+ v = (
+ self.to_v(encoder_hidden_states)
+ if encoder_hidden_states is not None
+ else self.to_v(hidden_states)
+ )
+ b, t, c = q.size()
+
+ q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+ k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+ v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim).transpose(1, 2)
+
+ # scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
+ # attn_weights = torch.softmax(scores, dim=-1)
+ # attn_output = torch.matmul(attn_weights, v)
+
+ attn_output = F.scaled_dot_product_attention(
+ q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous().view(b, t, c)
+
+ for layer in self.to_out:
+ attn_output = layer(attn_output)
+
+ return attn_output
+
+ def set_processor(self, processor) -> None:
+ r"""
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ print(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False):
+ r"""
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
+ # with PEFT is completed.
+ is_lora_activated = {
+ name: module.lora_layer is not None
+ for name, module in self.named_modules()
+ if hasattr(module, "lora_layer")
+ }
+
+ # 1. if no layer has a LoRA activated we can return the processor as usual
+ if not any(is_lora_activated.values()):
+ return self.processor
+
+        # `add_k_proj` and `add_v_proj` may not have LoRA applied, so ignore them here
+ is_lora_activated.pop("add_k_proj", None)
+ is_lora_activated.pop("add_v_proj", None)
+ # 2. else it is not possible that only some layers have LoRA activated
+ if not all(is_lora_activated.values()):
+ raise ValueError(
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
+ )
+
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
+ non_lora_processor_cls_name = self.processor.__class__.__name__
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
+
+ hidden_size = self.inner_dim
+
+ # now create a LoRA attention processor from the LoRA layers
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
+ kwargs = {
+ "cross_attention_dim": self.cross_attention_dim,
+ "rank": self.to_q.lora_layer.rank,
+ "network_alpha": self.to_q.lora_layer.network_alpha,
+ "q_rank": self.to_q.lora_layer.rank,
+ "q_hidden_size": self.to_q.lora_layer.out_features,
+ "k_rank": self.to_k.lora_layer.rank,
+ "k_hidden_size": self.to_k.lora_layer.out_features,
+ "v_rank": self.to_v.lora_layer.rank,
+ "v_hidden_size": self.to_v.lora_layer.out_features,
+ "out_rank": self.to_out[0].lora_layer.rank,
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
+ }
+
+ if hasattr(self.processor, "attention_op"):
+ kwargs["attention_op"] = self.processor.attention_op
+
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
+ lora_processor = lora_processor_cls(
+ hidden_size,
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
+ rank=self.to_q.lora_layer.rank,
+ network_alpha=self.to_q.lora_layer.network_alpha,
+ )
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
+
+ # only save if used
+ if self.add_k_proj.lora_layer is not None:
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
+ else:
+ lora_processor.add_k_proj_lora = None
+ lora_processor.add_v_proj_lora = None
+ else:
+ raise ValueError(f"{lora_processor_cls} does not exist.")
+
+ return lora_processor
+
+class GEGLU(nn.Module):
+ def __init__(self, in_features, out_features):
+ super(GEGLU, self).__init__()
+ self.proj = nn.Linear(in_features, out_features * 2, bias=True)
+
+ def forward(self, x):
+ x_proj = self.proj(x)
+ x1, x2 = x_proj.chunk(2, dim=-1)
+ return x1 * torch.nn.functional.gelu(x2)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, in_features, out_features):
+ super(FeedForward, self).__init__()
+
+ self.net = nn.ModuleList(
+ [
+ GEGLU(in_features, out_features * 4),
+ nn.Dropout(p=0.0, inplace=False),
+ nn.Linear(out_features * 4, out_features, bias=True),
+ ]
+ )
+
+ def forward(self, x):
+ for layer in self.net:
+ x = layer(x)
+ return x
+
+
+class BasicTransformerBlock(nn.Module):
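+    # Pre-LayerNorm transformer block: self-attention, cross-attention against the
+    # 2048-dim text context, and a GEGLU feed-forward, each with a residual connection.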
+ def __init__(self, hidden_size):
+ super(BasicTransformerBlock, self).__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
+ self.attn1 = Attention(hidden_size)
+ self.norm2 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
+ self.attn2 = Attention(hidden_size, 2048)
+ self.norm3 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
+ self.ff = FeedForward(hidden_size, hidden_size)
+
+ def forward(self, x, encoder_hidden_states=None):
+ residual = x
+
+ x = self.norm1(x)
+ x = self.attn1(x)
+ x = x + residual
+
+ residual = x
+
+ x = self.norm2(x)
+ if encoder_hidden_states is not None:
+ x = self.attn2(x, encoder_hidden_states)
+ else:
+ x = self.attn2(x)
+ x = x + residual
+
+ residual = x
+
+ x = self.norm3(x)
+ x = self.ff(x)
+ x = x + residual
+ return x
+
+
+class Transformer2DModel(nn.Module):
+ def __init__(self, in_channels, out_channels, n_layers):
+ super(Transformer2DModel, self).__init__()
+ self.norm = nn.GroupNorm(32, in_channels, eps=1e-06, affine=True)
+ self.proj_in = nn.Linear(in_channels, out_channels, bias=True)
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(out_channels) for _ in range(n_layers)]
+ )
+ self.proj_out = nn.Linear(out_channels, out_channels, bias=True)
+
+ def forward(self, hidden_states, encoder_hidden_states=None):
+ batch, _, height, width = hidden_states.shape
+ res = hidden_states
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+ batch, height * width, inner_dim
+ )
+ hidden_states = self.proj_in(hidden_states)
+
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states)
+
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, width, inner_dim)
+ .permute(0, 3, 1, 2)
+ .contiguous()
+ )
+
+ return hidden_states + res
+
+
+class Downsample2D(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(Downsample2D, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=2, padding=1
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Upsample2D(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(Upsample2D, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
+ )
+
+ def forward(self, x):
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ return self.conv(x)
+
+
+class DownBlock2D(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(DownBlock2D, self).__init__()
+ self.resnets = nn.ModuleList(
+ [
+ ResnetBlock2D(in_channels, out_channels, conv_shortcut=False),
+ ResnetBlock2D(out_channels, out_channels, conv_shortcut=False),
+ ]
+ )
+ self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
+
+ def forward(self, hidden_states, temb):
+ output_states = []
+ for module in self.resnets:
+ hidden_states = module(hidden_states, temb)
+ output_states.append(hidden_states)
+
+ hidden_states = self.downsamplers[0](hidden_states)
+ output_states.append(hidden_states)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(self, in_channels, out_channels, n_layers, has_downsamplers=True):
+ super(CrossAttnDownBlock2D, self).__init__()
+ self.attentions = nn.ModuleList(
+ [
+ Transformer2DModel(out_channels, out_channels, n_layers),
+ Transformer2DModel(out_channels, out_channels, n_layers),
+ ]
+ )
+ self.resnets = nn.ModuleList(
+ [
+ ResnetBlock2D(in_channels, out_channels),
+ ResnetBlock2D(out_channels, out_channels, conv_shortcut=False),
+ ]
+ )
+ self.downsamplers = None
+ if has_downsamplers:
+ self.downsamplers = nn.ModuleList(
+ [Downsample2D(out_channels, out_channels)]
+ )
+
+ def forward(self, hidden_states, temb, encoder_hidden_states):
+ output_states = []
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ output_states.append(hidden_states)
+
+ if self.downsamplers is not None:
+ hidden_states = self.downsamplers[0](hidden_states)
+ output_states.append(hidden_states)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(self, in_channels, out_channels, prev_output_channel, n_layers):
+ super(CrossAttnUpBlock2D, self).__init__()
+ self.attentions = nn.ModuleList(
+ [
+ Transformer2DModel(out_channels, out_channels, n_layers),
+ Transformer2DModel(out_channels, out_channels, n_layers),
+ Transformer2DModel(out_channels, out_channels, n_layers),
+ ]
+ )
+ self.resnets = nn.ModuleList(
+ [
+ ResnetBlock2D(prev_output_channel + out_channels, out_channels),
+ ResnetBlock2D(2 * out_channels, out_channels),
+ ResnetBlock2D(out_channels + in_channels, out_channels),
+ ]
+ )
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
+
+ def forward(
+ self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UpBlock2D(nn.Module):
+ def __init__(self, in_channels, out_channels, prev_output_channel):
+ super(UpBlock2D, self).__init__()
+ self.resnets = nn.ModuleList(
+ [
+ ResnetBlock2D(out_channels + prev_output_channel, out_channels),
+ ResnetBlock2D(out_channels * 2, out_channels),
+ ResnetBlock2D(out_channels + in_channels, out_channels),
+ ]
+ )
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ and getattr(self, "resolution_idx", None)
+ )
+
+ for resnet in self.resnets:
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(self, in_features):
+ super(UNetMidBlock2DCrossAttn, self).__init__()
+ self.attentions = nn.ModuleList(
+ [Transformer2DModel(in_features, in_features, n_layers=10)]
+ )
+ self.resnets = nn.ModuleList(
+ [
+ ResnetBlock2D(in_features, in_features, conv_shortcut=False),
+ ResnetBlock2D(in_features, in_features, conv_shortcut=False),
+ ]
+ )
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNet2DConditionModel(nn.Module):
+ def __init__(self):
+ super(UNet2DConditionModel, self).__init__()
+
+ # This is needed to imitate huggingface config behavior
+ # has nothing to do with the model itself
+ # remove this if you don't use diffuser's pipeline
+ self.config = namedtuple(
+ "config", "in_channels addition_time_embed_dim sample_size"
+ )
+ self.config.in_channels = 4
+ self.config.addition_time_embed_dim = 256
+ self.config.sample_size = 128
+
+ self.conv_in = nn.Conv2d(4, 320, kernel_size=3, stride=1, padding=1)
+ self.time_proj = Timesteps()
+ self.time_embedding = TimestepEmbedding(in_features=320, out_features=1280)
+ self.add_time_proj = Timesteps(256)
+ self.add_embedding = TimestepEmbedding(in_features=2816, out_features=1280)
+ self.down_blocks = nn.ModuleList(
+ [
+ DownBlock2D(in_channels=320, out_channels=320),
+ CrossAttnDownBlock2D(in_channels=320, out_channels=640, n_layers=2),
+ CrossAttnDownBlock2D(
+ in_channels=640,
+ out_channels=1280,
+ n_layers=10,
+ has_downsamplers=False,
+ ),
+ ]
+ )
+ self.up_blocks = nn.ModuleList(
+ [
+ CrossAttnUpBlock2D(
+ in_channels=640,
+ out_channels=1280,
+ prev_output_channel=1280,
+ n_layers=10,
+ ),
+ CrossAttnUpBlock2D(
+ in_channels=320,
+ out_channels=640,
+ prev_output_channel=1280,
+ n_layers=2,
+ ),
+ UpBlock2D(in_channels=320, out_channels=320, prev_output_channel=640),
+ ]
+ )
+ self.mid_block = UNetMidBlock2DCrossAttn(1280)
+ self.conv_norm_out = nn.GroupNorm(32, 320, eps=1e-05, affine=True)
+ self.conv_act = nn.SiLU()
+ self.conv_out = nn.Conv2d(320, 4, kernel_size=3, stride=1, padding=1)
+
+ def forward(
+ self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, **kwargs
+ ):
+        # 1. time embedding plus SDXL added conditions (pooled text embeds and micro-conditioning time_ids)
+ timesteps = timesteps.expand(sample.shape[0])
+ t_emb = self.time_proj(timesteps).to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb)
+
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ time_ids = added_cond_kwargs.get("time_ids")
+
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ emb = emb + aug_emb
+
+ sample = self.conv_in(sample)
+
+ # 3. down
+ s0 = sample
+ sample, [s1, s2, s3] = self.down_blocks[0](
+ sample,
+ temb=emb,
+ )
+
+ sample, [s4, s5, s6] = self.down_blocks[1](
+ sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ sample, [s7, s8] = self.down_blocks[2](
+ sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ # 4. mid
+ sample = self.mid_block(
+ sample, emb, encoder_hidden_states=encoder_hidden_states
+ )
+
+ # 5. up
+ sample = self.up_blocks[0](
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=[s6, s7, s8],
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ sample = self.up_blocks[1](
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=[s3, s4, s5],
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ sample = self.up_blocks[2](
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=[s0, s1, s2],
+ )
+
+ # 6. post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ return [sample]
\ No newline at end of file
diff --git a/module/unet/unet_2d_ZeroSFT.py b/module/unet/unet_2d_ZeroSFT.py
new file mode 100644
index 0000000000000000000000000000000000000000..c91286fcf0d119c5e04c202b75a27cf5e8fdbd59
--- /dev/null
+++ b/module/unet/unet_2d_ZeroSFT.py
@@ -0,0 +1,1397 @@
+# Copy from diffusers.models.unets.unet_2d_condition.py
+
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ GLIGENTextBoundingboxProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from .unet_2d_ZeroSFT_blocks import (
+ get_down_block,
+ get_mid_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
+
+
+class ZeroConv(nn.Module):
+ def __init__(self, label_nc, norm_nc, mask=False):
+ super().__init__()
+ self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0))
+ self.mask = mask
+
+ def forward(self, c, h, h_ori=None):
+ # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
+ if not self.mask:
+ h = h + self.zero_conv(c)
+ else:
+ h = h + self.zero_conv(c) * torch.zeros_like(h)
+ if h_ori is not None:
+ h = torch.cat([h_ori, h], dim=1)
+ return h
+
+
+class ZeroSFT(nn.Module):
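+    # Zero-initialized SFT (scale/shift) modulation: control features `c` produce a
+    # per-pixel (gamma, beta) via zero convs that modulates the group-normalized UNet
+    # features, on top of a ZeroConv residual path; `control_scale` blends the modulated
+    # result with the unmodified features.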
+ def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False):
+ super().__init__()
+
+ # param_free_norm_type = str(parsed.group(1))
+ ks = 3
+ pw = ks // 2
+
+ self.mask = mask
+ self.norm = norm
+ self.pre_concat = bool(concat_channels != 0)
+ if self.norm:
+ self.param_free_norm = torch.nn.GroupNorm(num_groups=32, num_channels=norm_nc + concat_channels)
+ else:
+ self.param_free_norm = nn.Identity()
+
+ nhidden = 128
+
+ self.mlp_shared = nn.Sequential(
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
+ nn.SiLU()
+ )
+ self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
+ self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
+
+ self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0))
+
+ def forward(self, down_block_res_samples, h_ori=None, control_scale=1.0, mask=False):
+ mask = mask or self.mask
+ assert mask is False
+ if self.pre_concat:
+ assert h_ori is not None
+
+        c, h = down_block_res_samples
+ if h_ori is not None:
+ h_raw = torch.cat([h_ori, h], dim=1)
+ else:
+ h_raw = h
+
+ if self.mask:
+ h = h + self.zero_conv(c) * torch.zeros_like(h)
+ else:
+ h = h + self.zero_conv(c)
+ if h_ori is not None and self.pre_concat:
+ h = torch.cat([h_ori, h], dim=1)
+ actv = self.mlp_shared(c)
+ gamma = self.zero_mul(actv)
+ beta = self.zero_add(actv)
+ if self.mask:
+ gamma = gamma * torch.zeros_like(gamma)
+ beta = beta * torch.zeros_like(beta)
+ # h = h + self.param_free_norm(h) * gamma + beta
+ h = self.param_free_norm(h) * (gamma + 1) + beta
+ if h_ori is not None and not self.pre_concat:
+ h = torch.cat([h_ori, h], dim=1)
+ return h * control_scale + h_raw * (1 - control_scale)
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DZeroSFTModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+ If `None`, normalization and activation layers are skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
+ otherwise.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ self._check_config(
+ down_block_types=down_block_types,
+ up_block_types=up_block_types,
+ only_cross_attention=only_cross_attention,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
+ time_embedding_type,
+ block_out_channels=block_out_channels,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ time_embedding_dim=time_embedding_dim,
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ self._set_encoder_hid_proj(
+ encoder_hid_dim_type,
+ cross_attention_dim=cross_attention_dim,
+ encoder_hid_dim=encoder_hid_dim,
+ )
+
+ # class embedding
+ self._set_class_embedding(
+ class_embed_type,
+ act_fn=act_fn,
+ num_class_embeds=num_class_embeds,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ time_embed_dim=time_embed_dim,
+ timestep_input_dim=timestep_input_dim,
+ )
+
+ self._set_add_embedding(
+ addition_embed_type,
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
+ addition_time_embed_dim=addition_time_embed_dim,
+ cross_attention_dim=cross_attention_dim,
+ encoder_hid_dim=encoder_hid_dim,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ time_embed_dim=time_embed_dim,
+ )
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = get_mid_block(
+ mid_block_type,
+ temb_channels=blocks_time_embed_dim,
+ in_channels=block_out_channels[-1],
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ output_scale_factor=mid_block_scale_factor,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ num_attention_heads=num_attention_heads[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[-1],
+ dropout=dropout,
+ )
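+ # ZeroSFT connector used in `forward` ("# 4. mid") to fuse the ControlNet mid-block residual
+ # into the mid-block output.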
+ self.mid_zero_SFT = ZeroSFT(block_out_channels[-1], block_out_channels[-1], 0)
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = (
+ list(reversed(transformer_layers_per_block))
+ if reverse_transformer_layers_per_block is None
+ else reverse_transformer_layers_per_block
+ )
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
+
+ def _check_config(
+ self,
+ down_block_types: Tuple[str],
+ up_block_types: Tuple[str],
+ only_cross_attention: Union[bool, Tuple[bool]],
+ block_out_channels: Tuple[int],
+ layers_per_block: Union[int, Tuple[int]],
+ cross_attention_dim: Union[int, Tuple[int]],
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
+ reverse_transformer_layers_per_block: bool,
+ attention_head_dim: int,
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
+ ):
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+ for layer_number_per_block in transformer_layers_per_block:
+ if isinstance(layer_number_per_block, list):
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+ def _set_time_proj(
+ self,
+ time_embedding_type: str,
+ block_out_channels: int,
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ time_embedding_dim: int,
+ ) -> Tuple[int, int]:
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ return time_embed_dim, timestep_input_dim
+
+ def _set_encoder_hid_proj(
+ self,
+ encoder_hid_dim_type: Optional[str],
+ cross_attention_dim: Union[int, Tuple[int]],
+ encoder_hid_dim: Optional[int],
+ ):
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ def _set_class_embedding(
+ self,
+ class_embed_type: Optional[str],
+ act_fn: str,
+ num_class_embeds: Optional[int],
+ projection_class_embeddings_input_dim: Optional[int],
+ time_embed_dim: int,
+ timestep_input_dim: int,
+ ):
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ def _set_add_embedding(
+ self,
+ addition_embed_type: str,
+ addition_embed_type_num_heads: int,
+ addition_time_embed_dim: Optional[int],
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ cross_attention_dim: Optional[int],
+ encoder_hid_dim: Optional[int],
+ projection_class_embeddings_input_dim: Optional[int],
+ time_embed_dim: int,
+ ):
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
+ if attention_type in ["gated", "gated-text-image"]:
+ positive_len = 768
+ if isinstance(cross_attention_dim, int):
+ positive_len = cross_attention_dim
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+ positive_len = cross_attention_dim[0]
+
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
+ self.position_net = GLIGENTextBoundingboxProjection(
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ This API is 🧪 experimental.
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ This API is 🧪 experimental.
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def unload_lora(self):
+ """Unloads LoRA weights."""
+ deprecate(
+ "unload_lora",
+ "0.28.0",
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
+ )
+ for module in self.modules():
+ if hasattr(module, "set_lora_layer"):
+ module.set_lora_layer(None)
+
+ def get_time_embed(
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
+ ) -> Optional[torch.Tensor]:
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+ return t_emb
+
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ class_emb = None
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+ return class_emb
+
+ def get_aug_embed(
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> Optional[torch.Tensor]:
+ aug_emb = None
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb = self.add_embedding(image_embs, hint)
+ return aug_emb
+
+ def process_encoder_hidden_states(
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> torch.Tensor:
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ image_embeds = self.encoder_hid_proj(image_embeds)
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
+ return encoder_hidden_states
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
+ break
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
+ if class_emb is not None:
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ aug_emb = self.get_aug_embed(
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+ if self.config.addition_embed_type == "image_hint":
+ aug_emb, hint = aug_emb
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ encoder_hidden_states = self.process_encoder_hidden_states(
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
+ if cross_attention_kwargs is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
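+ # Unlike the stock UNet2DConditionModel, ControlNet residuals are not summed into the skip
+ # features here; they are packed as (residual, sample) pairs so the ZeroSFT-aware up blocks
+ # (defined in unet_2d_ZeroSFT_blocks.py) can fuse each pair themselves.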
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_additional_residual, down_block_res_sample in zip(
+ down_block_additional_residuals, down_block_res_samples
+ ):
+ down_block_res_sample_tuple = (down_block_additional_residual, down_block_res_sample)
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample_tuple,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
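+ # The ControlNet mid-block residual is fused through ZeroSFT modulation rather than plain addition.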
+ if is_controlnet:
+ sample = self.mid_zero_SFT((mid_block_additional_residual, sample))
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
diff --git a/module/unet/unet_2d_ZeroSFT_blocks.py b/module/unet/unet_2d_ZeroSFT_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..66b17c7b0a0f0b0dab77d05d0a10a1d564bae763
--- /dev/null
+++ b/module/unet/unet_2d_ZeroSFT_blocks.py
@@ -0,0 +1,3862 @@
+# Copied from diffusers/models/unets/unet_2d_blocks.py
+
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import deprecate, is_torch_version, logging
+from diffusers.utils.torch_utils import apply_freeu
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
+from diffusers.models.normalization import AdaGroupNorm
+from diffusers.models.resnet import (
+ Downsample2D,
+ FirDownsample2D,
+ FirUpsample2D,
+ KDownsample2D,
+ KUpsample2D,
+ ResnetBlock2D,
+ ResnetBlockCondNorm2D,
+ Upsample2D,
+)
+from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
+from diffusers.models.transformers.transformer_2d import Transformer2DModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def get_down_block(
+ down_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ add_downsample: bool,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ downsample_padding: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = None,
+ downsample_type: Optional[str] = None,
+ dropout: float = 0.0,
+):
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "ResnetDownsampleBlock2D":
+ return ResnetDownsampleBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ if add_downsample is False:
+ downsample_type = None
+ else:
+ downsample_type = downsample_type or "conv" # default to 'conv'
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ downsample_type=downsample_type,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ )
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
+ return SimpleCrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnDownEncoderBlock2D":
+ return AttnDownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "KDownBlock2D":
+ return KDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif down_block_type == "KCrossAttnDownBlock2D":
+ return KCrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ add_self_attention=not add_downsample,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_mid_block(
+ mid_block_type: str,
+ temb_channels: int,
+ in_channels: int,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ resnet_groups: int,
+ output_scale_factor: float = 1.0,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ mid_block_only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = 1,
+ dropout: float = 0.0,
+):
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ return UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ resnet_groups=resnet_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ return UNetMidBlock2DSimpleCrossAttn(
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ return UNetMidBlock2D(
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ num_layers=0,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ elif mid_block_type is None:
+ return None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+
+def get_up_block(
+ up_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ add_upsample: bool,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ resolution_idx: Optional[int] = None,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = None,
+ upsample_type: Optional[str] = None,
+ dropout: float = 0.0,
+) -> nn.Module:
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "ResnetUpsampleBlock2D":
+ return ResnetUpsampleBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ )
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
+ return SimpleCrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ if add_upsample is False:
+ upsample_type = None
+ else:
+ upsample_type = upsample_type or "conv" # default to 'conv'
+
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ upsample_type=upsample_type,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temb_channels=temb_channels,
+ )
+ elif up_block_type == "AttnUpDecoderBlock2D":
+ return AttnUpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temb_channels=temb_channels,
+ )
+ elif up_block_type == "KUpBlock2D":
+ return KUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "KCrossAttnUpBlock2D":
+ return KCrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ )
+
+ raise ValueError(f"{up_block_type} does not exist.")
+
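+# Hedged usage sketch for the factory helpers above: `get_down_block` / `get_up_block`
+# map a block-type string to the corresponding module. The channel sizes below are
+# hypothetical, SDXL-like values chosen only for illustration.
+#
+#   block = get_down_block(
+#       "CrossAttnDownBlock2D", num_layers=2, in_channels=320, out_channels=640,
+#       temb_channels=1280, add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
+#       num_attention_heads=10, cross_attention_dim=2048, attention_head_dim=64,
+#   )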
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
+
+
+class ZeroConv(nn.Module):
+ def __init__(self, label_nc, norm_nc, mask=False):
+ super().__init__()
+ self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0))
+ self.mask = mask
+
+ def forward(self, c, h, h_ori=None):
+ # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
+ if not self.mask:
+ h = h + self.zero_conv(c)
+ else:
+ h = h + self.zero_conv(c) * torch.zeros_like(h)
+ if h_ori is not None:
+ h = torch.cat([h_ori, h], dim=1)
+ return h
+
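+# ZeroConv adds a zero-initialized 1x1 projection of the control features `c` to the UNet
+# features `h`; because the conv weights start at zero, it acts as an identity on `h` right
+# after initialization. Hedged sketch with hypothetical channel sizes:
+#
+#   zc = ZeroConv(label_nc=320, norm_nc=320)
+#   h_out = zc(c, h)   # == h at initialization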
+
+class ZeroSFT(nn.Module):
+ def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False):
+ super().__init__()
+
+ # param_free_norm_type = str(parsed.group(1))
+ ks = 3
+ pw = ks // 2
+
+ self.mask = mask
+ self.norm = norm
+ self.pre_concat = bool(concat_channels != 0)
+ if self.norm:
+ self.param_free_norm = torch.nn.GroupNorm(num_groups=32, num_channels=norm_nc + concat_channels)
+ else:
+ self.param_free_norm = nn.Identity()
+
+ nhidden = 128
+
+ self.mlp_shared = nn.Sequential(
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
+ nn.SiLU()
+ )
+ self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
+ self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
+
+ self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0))
+
+ def forward(self, down_block_res_samples, h_ori=None, control_scale=1.0, mask=False):
+ mask = mask or self.mask
+ assert mask is False
+ if self.pre_concat:
+ assert h_ori is not None
+
+ c, h = down_block_res_samples
+ if h_ori is not None:
+ h_raw = torch.cat([h_ori, h], dim=1)
+ else:
+ h_raw = h
+
+ if self.mask:
+ h = h + self.zero_conv(c) * torch.zeros_like(h)
+ else:
+ h = h + self.zero_conv(c)
+ if h_ori is not None and self.pre_concat:
+ h_ori_c = h_ori.shape[1]
+ h_c = h.shape[1]
+ h = torch.cat([h_ori, h], dim=1)
+ actv = self.mlp_shared(c)
+ gamma = self.zero_mul(actv)
+ beta = self.zero_add(actv)
+ if self.mask:
+ gamma = gamma * torch.zeros_like(gamma)
+ beta = beta * torch.zeros_like(beta)
+ # gamma_ori, gamma_res = torch.split(gamma, [h_ori_c, h_c], dim=1)
+ # beta_ori, beta_res = torch.split(beta, [h_ori_c, h_c], dim=1)
+ # print(gamma_ori.mean(), gamma_res.mean(), beta_ori.mean(), beta_res.mean())
+ # h = h + self.param_free_norm(h) * gamma + beta
+ h = self.param_free_norm(h) * (gamma + 1) + beta
+ # sample_ori, sample_res = torch.split(h, [h_ori_c, h_c], dim=1)
+ # print(sample_ori.mean(), sample_res.mean())
+ if h_ori is not None and not self.pre_concat:
+ h = torch.cat([h_ori, h], dim=1)
+ return h * control_scale + h_raw * (1 - control_scale)
+
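+# ZeroSFT applies a SPADE-style spatial feature transform whose scale/shift convolutions are
+# zero-initialized, so at initialization the control features `c` have no effect and the output
+# reduces to the (optionally group-normalized) UNet features blended by `control_scale`.
+# Hedged sketch with hypothetical channel sizes:
+#
+#   sft = ZeroSFT(label_nc=320, norm_nc=320, norm=False)
+#   h_out = sft((c, h), control_scale=1.0)   # == h at initialization because norm=False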
+
+class AutoencoderTinyBlock(nn.Module):
+ """
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
+ blocks.
+
+ Args:
+ in_channels (`int`): The number of input channels.
+ out_channels (`int`): The number of output channels.
+ act_fn (`str`):
+ The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
+
+ Returns:
+ `torch.FloatTensor`: A tensor with the same spatial dimensions as the input, but with the number of
+ channels equal to `out_channels`.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
+ super().__init__()
+ act_fn = get_activation(act_fn)
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+ act_fn,
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+ act_fn,
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+ )
+ self.skip = (
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+ if in_channels != out_channels
+ else nn.Identity()
+ )
+ self.fuse = nn.ReLU()
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ return self.fuse(self.conv(x) + self.skip(x))
+
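+# Hedged usage sketch for the block above (sizes are illustrative):
+#
+#   blk = AutoencoderTinyBlock(in_channels=64, out_channels=64, act_fn="relu")
+#   y = blk(torch.randn(1, 64, 32, 32))   # same spatial size, 64 output channels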
+
+class UNetMidBlock2D(nn.Module):
+ """
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
+
+ Args:
+ in_channels (`int`): The number of input channels.
+ temb_channels (`int`): The number of temporal embedding channels.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
+ model on tasks with long-range temporal dependencies.
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
+ resnet_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use in the group normalization layers of the resnet blocks.
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
+ Whether to use pre-normalization for the resnet blocks.
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
+ attention_head_dim (`int`, *optional*, defaults to 1):
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
+ the number of input channels.
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
+
+ Returns:
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
+ in_channels, height, width)`.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ attn_groups: Optional[int] = None,
+ resnet_pre_norm: bool = True,
+ add_attention: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+ self.add_attention = add_attention
+
+ if attn_groups is None:
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+
+ # there is always at least one resnet
+ if resnet_time_scale_shift == "spatial":
+ resnets = [
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ ]
+ else:
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
+ )
+ attention_head_dim = in_channels
+
+ for _ in range(num_layers):
+ if self.add_attention:
+ attentions.append(
+ Attention(
+ in_channels,
+ heads=in_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=attn_groups,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+ else:
+ attentions.append(None)
+
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
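+# Hedged usage sketch (illustrative sizes): with the default `num_layers=1` the block runs
+# resnet -> attention -> resnet and preserves the input shape.
+#
+#   mid = UNetMidBlock2D(in_channels=512, temb_channels=512, attention_head_dim=512)
+#   out = mid(torch.randn(1, 512, 8, 8), temb=torch.randn(1, 512))   # (1, 512, 8, 8)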
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ out_channels: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_groups_out: Optional[int] = None,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ output_scale_factor: float = 1.0,
+ cross_attention_dim: int = 1280,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ ):
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # support for variable transformer layers per block
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ resnet_groups_out = resnet_groups_out or resnet_groups
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ groups_out=resnet_groups_out,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for i in range(num_layers):
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups_out,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups_out,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = attn(
+ hidden_states,
+ timestep=temb,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = attn(
+ hidden_states,
+ timestep=temb,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DSimpleCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ cross_attention_dim: int = 1280,
+ skip_time_act: bool = False,
+ only_cross_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+
+ self.attention_head_dim = attention_head_dim
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ self.num_heads = in_channels // self.attention_head_dim
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=in_channels,
+ cross_attention_dim=in_channels,
+ heads=self.num_heads,
+ dim_head=self.attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ # attn
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ # resnet
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class AttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ downsample_padding: int = 1,
+ downsample_type: str = "conv",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ self.downsample_type = downsample_type
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if downsample_type == "conv":
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ elif downsample_type == "resnet":
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ if self.downsample_type == "resnet":
+ hidden_states = downsampler(hidden_states, temb=temb)
+ else:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ downsample_padding: int = 1,
+ add_downsample: bool = True,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ additional_residuals: Optional[torch.FloatTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ output_states = ()
+
+ blocks = list(zip(self.resnets, self.attentions))
+
+ for i, (resnet, attn) in enumerate(blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ timestep=temb,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ timestep=temb,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None)
+ hidden_states = attn(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ self.attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=32,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class SkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample: Optional[torch.FloatTensor] = None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class ResnetDownsampleBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ skip_time_act: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
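+ # wrap the resnet so torch.utils.checkpoint can re-run its forward pass during the backward pass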
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class SimpleCrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ skip_time_act: bool = False,
+ only_cross_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+
+ resnets = []
+ attentions = []
+
+ self.attention_head_dim = attention_head_dim
+ self.num_heads = out_channels // self.attention_head_dim
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
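+ # prefer the fused added-KV attention processor when PyTorch 2.0's scaled_dot_product_attention is available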
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=out_channels,
+ cross_attention_dim=out_channels,
+ heads=self.num_heads,
+ dim_head=attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ output_states = ()
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, temb)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class KDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: int = 32,
+ add_downsample: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=groups,
+ groups_out=groups_out,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ # YiYi's comment: might be able to use FirDownsample2D here; look into the details later
+ self.downsamplers = nn.ModuleList([KDownsample2D()])
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states, output_states
+
+
+class KCrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ cross_attention_dim: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_group_size: int = 32,
+ add_downsample: bool = True,
+ attention_head_dim: int = 64,
+ add_self_attention: bool = False,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=groups,
+ groups_out=groups_out,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+ attentions.append(
+ KAttentionBlock(
+ out_channels,
+ out_channels // attention_head_dim,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ temb_channels=temb_channels,
+ attention_bias=True,
+ add_self_attention=add_self_attention,
+ cross_attention_norm="layer_norm",
+ group_size=resnet_group_size,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList([KDownsample2D()])
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
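+ # when no downsampling happens, store a None placeholder so the matching K up block can skip the skip-connection concat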
+ if self.downsamplers is None:
+ output_states += (None,)
+ else:
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states, output_states
+
+
+class AttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ upsample_type: str = "conv",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.upsample_type = upsample_type
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
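+ # "conv" upsamples with an interpolation + conv (Upsample2D); "resnet" upsamples inside a time-conditioned ResnetBlock2D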
+ if upsample_type == "conv":
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ elif upsample_type == "resnet":
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ if self.upsample_type == "resnet":
+ hidden_states = upsampler(hidden_states, temb=temb)
+ else:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ zero_SFTs = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
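+ # InstantIR addition (not in the stock diffusers block): one ZeroSFT module per resnet, used in forward()
+ # to fuse tuple-valued skip connections into the hidden states (per its name, presumably a zero-initialized SFT layer).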
+ zero_SFTs.append(
+ ZeroSFT(
+ res_skip_channels,
+ res_skip_channels,
+ resnet_in_channels,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.zero_SFTs = nn.ModuleList(zero_SFTs)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
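+ # FreeU counts as enabled once its scaling factors s1, s2, b1, b2 have been set on this block (e.g. by the pipeline's enable_freeu)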
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
+ for resnet, attn, zero_SFT in zip(self.resnets, self.attentions, self.zero_SFTs):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
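+ # InstantIR path: a tuple-valued skip connection is fused into hidden_states via ZeroSFT
+ # instead of being concatenated along the channel dimension below.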
+ if isinstance(res_hidden_states, tuple):
+ # ZeroSFT
+ hidden_states = zero_SFT(res_hidden_states, hidden_states)
+ else:
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states[1] + res_hidden_states[0],
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ timestep=temb,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ timestep=temb,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+ zero_SFTs = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ zero_SFTs.append(
+ ZeroSFT(
+ res_skip_channels,
+ res_skip_channels,
+ resnet_in_channels,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.zero_SFTs = nn.ModuleList(zero_SFTs)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
+ for resnet, zero_SFT in zip(self.resnets, self.zero_SFTs):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
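+ # InstantIR path: tuple-valued skips are fused via ZeroSFT instead of channel concatenation (same as in CrossAttnUpBlock2D)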
+ if isinstance(res_hidden_states, tuple):
+ # ZeroSFT
+ hidden_states = zero_SFT(res_hidden_states, hidden_states)
+ else:
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states[1] + res_hidden_states[0],
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ temb_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnUpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ temb_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=temb)
+ hidden_states = attn(hidden_states, temb=temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ self.attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=32,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample=None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = self.attentions[0](hidden_states)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
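+ # project the features back to 3 channels and add them to the running skip image before upsampling the features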
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_upsample: bool = True,
+ upsample_padding: int = 1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample=None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
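+ # as in AttnSkipUpBlock2D: project the features to 3 channels and accumulate them into the running skip image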
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb)
+
+ return hidden_states, skip_sample
+
+
+class ResnetUpsampleBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ skip_time_act: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+class SimpleCrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ skip_time_act: bool = False,
+ only_cross_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attention_head_dim = attention_head_dim
+
+ self.num_heads = out_channels // self.attention_head_dim
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=out_channels,
+ cross_attention_dim=out_channels,
+ heads=self.num_heads,
+ dim_head=self.attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # resnet
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+class KUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: int,
+ dropout: float = 0.0,
+ num_layers: int = 5,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: Optional[int] = 32,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
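+ # k-unet channel bookkeeping: the first resnet consumes the upsampled features concatenated with the skip
+ # (hence 2 * out_channels), and the last resnet maps back to the encoder-side channel count (in_channels).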
+ k_in_channels = 2 * out_channels
+ k_out_channels = in_channels
+ num_layers = num_layers - 1
+
+ for i in range(num_layers):
+ in_channels = k_in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=k_out_channels if (i == num_layers - 1) else out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=groups,
+ groups_out=groups_out,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([KUpsample2D()])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
+ res_hidden_states_tuple = res_hidden_states_tuple[-1]
+ if res_hidden_states_tuple is not None:
+ hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class KCrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: int = 32,
+ attention_head_dim: int = 1, # attention dim_head
+ cross_attention_dim: int = 768,
+ add_upsample: bool = True,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
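+ # k-unet decoder heuristics: the first (lowest-resolution) up block adds self-attention and takes no
+ # concatenated skip input (k_in_channels == out_channels); middle blocks remap channels on their last resnet.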
+ is_first_block = in_channels == out_channels == temb_channels
+ is_middle_block = in_channels != out_channels
+ add_self_attention = is_first_block
+
+ self.has_cross_attention = True
+ self.attention_head_dim = attention_head_dim
+
+ # in_channels, and out_channels for the block (k-unet)
+ k_in_channels = out_channels if is_first_block else 2 * out_channels
+ k_out_channels = in_channels
+
+ num_layers = num_layers - 1
+
+ for i in range(num_layers):
+ in_channels = k_in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ if is_middle_block and (i == num_layers - 1):
+ conv_2d_out_channels = k_out_channels
+ else:
+ conv_2d_out_channels = None
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ conv_2d_out_channels=conv_2d_out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=groups,
+ groups_out=groups_out,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+ attentions.append(
+ KAttentionBlock(
+ k_out_channels if (i == num_layers - 1) else out_channels,
+ k_out_channels // attention_head_dim
+ if (i == num_layers - 1)
+ else out_channels // attention_head_dim,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ temb_channels=temb_channels,
+ attention_bias=True,
+ add_self_attention=add_self_attention,
+ cross_attention_norm="layer_norm",
+ upcast_attention=upcast_attention,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([KUpsample2D()])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ res_hidden_states_tuple = res_hidden_states_tuple[-1]
+ if res_hidden_states_tuple is not None:
+ hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+# can potentially later be renamed to `No-feed-forward` attention
+class KAttentionBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Configure if the attention layers should contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to upcast the attention computation to `float32`.
+ temb_channels (`int`, *optional*, defaults to 768):
+ The number of channels in the conditioning embedding `emb` used by the block's `AdaGroupNorm` layers (typically the timestep embedding).
+ add_self_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add self-attention to the block.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ group_size (`int`, *optional*, defaults to 32):
+ The number of groups to separate the channels into for group normalization.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout: float = 0.0,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ upcast_attention: bool = False,
+ temb_channels: int = 768, # for ada_group_norm
+ add_self_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ group_size: int = 32,
+ ):
+ super().__init__()
+ self.add_self_attention = add_self_attention
+
+ # 1. Self-Attn
+ if add_self_attention:
+ self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ cross_attention_norm=None,
+ )
+
+ # 2. Cross-Attn
+ self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+
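+ # (B, C, H, W) -> (B, H*W, C): flatten the spatial grid into tokens for attention ("weight" here means width)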
+ def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+ return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
+
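+ # (B, H*W, C) -> (B, C, H, W): restore the spatial grid after attention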
+ def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+ return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ # TODO: mark emb as non-optional (self.norm2 requires it).
+ # requires assessing impact of change to positional param interface.
+ emb: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+
+ # 1. Self-Attention
+ if self.add_self_attention:
+ norm_hidden_states = self.norm1(hidden_states, emb)
+
+ height, weight = norm_hidden_states.shape[2:]
+ norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output = self._to_4d(attn_output, height, weight)
+
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention/None
+ norm_hidden_states = self.norm2(hidden_states, emb)
+
+ height, weight = norm_hidden_states.shape[2:]
+ norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output = self._to_4d(attn_output, height, weight)
+
+ hidden_states = attn_output + hidden_states
+
+ return hidden_states
diff --git a/pipelines/sdxl_instantir.py b/pipelines/sdxl_instantir.py
new file mode 100644
index 0000000000000000000000000000000000000000..af6b562cead64a00c84cf1afc36de153d6e65511
--- /dev/null
+++ b/pipelines/sdxl_instantir.py
@@ -0,0 +1,1740 @@
+# Copyright 2024 The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from diffusers.utils.import_utils import is_invisible_watermark_available
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+ AttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers, LCMScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+ convert_unet_state_dict_to_peft
+)
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+from peft import LoraConfig, set_peft_model_state_dict
+from module.aggregator import Aggregator
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install diffusers pillow transformers accelerate
+ >>> import torch
+ >>> from PIL import Image
+
+ >>> from diffusers import DDPMScheduler
+ >>> from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
+
+ >>> from module.ip_adapter.utils import load_adapter_to_pipe
+ >>> from pipelines.sdxl_instantir import InstantIRPipeline
+
+ >>> # download models under ./models
+ >>> dcp_adapter = './models/adapter.pt'
+ >>> previewer_lora_path = './models'
+ >>> instantir_path = './models/aggregator.pt'
+
+ >>> # load pretrained models
+ >>> pipe = InstantIRPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, vae=vae, torch_dtype=torch.float16
+ ... )
+ >>> # load adapter
+ >>> load_adapter_to_pipe(
+ ... pipe,
+ ... dcp_adapter,
+ ... image_encoder_or_path = 'facebook/dinov2-large',
+ ... )
+ >>> # load previewer lora
+ >>> pipe.prepare_previewers(previewer_lora_path)
+ >>> pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
+ >>> lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
+
+ >>> # load aggregator weights
+ >>> pretrained_state_dict = torch.load(instantir_path)
+ >>> pipe.aggregator.load_state_dict(pretrained_state_dict)
+
+ >>> # send to GPU and fp16
+ >>> pipe.to(device="cuda", dtype=torch.float16)
+ >>> pipe.aggregator.to(device="cuda", dtype=torch.float16)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # load a broken image
+ >>> low_quality_image = Image.open('path/to/your-image').convert("RGB")
+
+ >>> # restoration
+ >>> image = pipe(
+ ... image=low_quality_image,
+ ... previewer_scheduler=lcm_scheduler,
+ ... ).images[0]
+ ```
+"""
+
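+# Module-name patterns targeted when injecting LoRA layers into the UNet: one set for the LCM
+# adapter and one for the previewer adapter. The `*_ip` / `ln_*_ip` entries cover the IP-adapter
+# key/value projections, which live on the attention processors rather than the attention modules.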
+LCM_LORA_MODULES = [
+ "to_q",
+ "to_k",
+ "to_v",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+]
+PREVIEWER_LORA_MODULES = [
+ "to_q",
+ "to_kv",
+ "0.to_out",
+ "attn1.to_k",
+ "attn1.to_v",
+ "to_k_ip",
+ "to_v_ip",
+ "ln_k_ip.linear",
+ "ln_v_ip.linear",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+]
+
+
+def remove_attn2(model):
+ """Recursively drop cross-attention (`attn2`) and its norm from the transformer blocks inside
+ the model's down/mid/up blocks."""
+ def recursive_find_module(name, module):
+ if "up_blocks" not in name and "down_blocks" not in name and "mid_block" not in name:
+ return
+ elif "resnets" in name:
+ return
+ if hasattr(module, "attn2"):
+ module.attn2 = None
+ module.norm2 = None
+ return
+ for sub_name, sub_module in module.named_children():
+ recursive_find_module(f"{name}.{sub_name}", sub_module)
+
+ for name, module in model.named_children():
+ recursive_find_module(name, module)
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class InstantIRPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ IPAdapterMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for image restoration (InstantIR) built on Stable Diffusion XL, with an `Aggregator` module
+ providing image-conditioned guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co./openai/clip-vit-large-patch14)).
+ text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
+ Second frozen text-encoder
+ ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co./laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ aggregator ([`Aggregator`], *optional*):
+ Provides additional, image-conditioned guidance to the `unet` during the denoising process. If not
+ passed, an aggregator is created from the `unet` via `Aggregator.from_unet` and its cross-attention
+ layers are removed.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+ Whether the negative prompt embeddings should always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1.0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
+ watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
+ watermarker is used.
+ """
+
+ # leave the aggregator out on purpose because it runs interleaved with the unet
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "feature_extractor",
+ "image_encoder",
+ ]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ aggregator: Aggregator = None,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ feature_extractor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ ):
+ super().__init__()
+
+ if aggregator is None:
+ aggregator = Aggregator.from_unet(unet)
+ remove_attn2(aggregator)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ aggregator=aggregator,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.control_image_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=True
+ )
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+
+ def prepare_previewers(self, previewer_lora_path: str, use_lcm=False):
+ if use_lcm:
+ lora_state_dict, alpha_dict = self.lora_state_dict(
+ previewer_lora_path,
+ )
+ else:
+ lora_state_dict, alpha_dict = self.lora_state_dict(
+ previewer_lora_path,
+ weight_name="previewer_lora_weights.bin"
+ )
+ unet_state_dict = {
+ k.replace("unet.", ""): v for k, v in lora_state_dict.items() if k.startswith("unet.")
+ }
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+ lora_state_dict = {}
+ for k, v in unet_state_dict.items():
+ # IP-adapter LoRA weights live on the attention processor, so remap `attn2` -> `attn2.processor`.
+ if "ip" in k:
+ k = k.replace("attn2", "attn2.processor")
+ lora_state_dict[k] = v
+ if alpha_dict:
+ lora_alpha = next(iter(alpha_dict.values()))
+ else:
+ lora_alpha = 1
+ logger.info(f"use lora alpha {lora_alpha}")
+ lora_config = LoraConfig(
+ r=64,
+ target_modules=LCM_LORA_MODULES if use_lcm else PREVIEWER_LORA_MODULES,
+ lora_alpha=lora_alpha,
+ lora_dropout=0.0,
+ )
+
+ adapter_name = "lcm" if use_lcm else "previewer"
+ self.unet.add_adapter(lora_config, adapter_name)
+ incompatible_keys = set_peft_model_state_dict(self.unet, lora_state_dict, adapter_name=adapter_name)
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
+ if unexpected_keys:
+ raise ValueError(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+ self.unet.disable_adapters()
+
+ return lora_alpha
+
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are always only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always only interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
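+ # Note: unlike the upstream method, the copy below also handles non-CLIP encoders (e.g. a DINOv2
+ # image encoder, as used in the example docstring), for which it returns `last_hidden_state` and
+ # encodes a zero image to obtain the unconditional embeddings.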
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ if isinstance(self.image_encoder, CLIPVisionModelWithProjection):
+ # CLIP image encoder.
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+ else:
+ # DINO image encoder.
+ image_embeds = self.image_encoder(image).last_hidden_state
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = self.image_encoder(
+ torch.zeros_like(image)
+ ).last_hidden_state
+ uncond_image_embeds = uncond_image_embeds.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ if isinstance(ip_adapter_image[0], list):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+ else:
+ logger.warning(
+ f"Got {len(ip_adapter_image)} images for {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ " By default, these images will be sent to each IP-Adapter. If this is not your use-case, please specify `ip_adapter_image` as a list of image-list, with"
+ f" length equals to the number of IP-Adapters."
+ )
+ ip_adapter_image = [ip_adapter_image] * len(self.unet.encoder_hid_proj.image_projection_layers)
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = isinstance(self.image_encoder, CLIPVisionModelWithProjection) and not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * (num_images_per_prompt//single_image_embeds.shape[0]), dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * (num_images_per_prompt//single_negative_image_embeds.shape[0]), dim=0
+ )
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ controlnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.aggregator, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.aggregator, Aggregator)
+ or is_compiled
+ and isinstance(self.aggregator._orig_mod, Aggregator)
+ ):
+ self.check_image(image, prompt, prompt_embeds)
+ else:
+ raise TypeError(f"`aggregator` must be an instance of `Aggregator`, but is {type(self.aggregator)}.")
+
+ if control_guidance_start >= control_guidance_end:
+ raise ValueError(
+ f"control guidance start: {control_guidance_start} cannot be larger or equal to control guidance end: {control_guidance_end}."
+ )
+ if control_guidance_start < 0.0:
+ raise ValueError(f"control guidance start: {control_guidance_start} can't be smaller than 0.")
+ if control_guidance_end > 1.0:
+ raise ValueError(f"control guidance end: {control_guidance_end} can't be larger than 1.0.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
+ def check_image(self, image, prompt, prompt_embeds):
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
+ def prepare_image(
+ self,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ ):
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ return image
+
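+ # Instead of starting from pure Gaussian noise, the latents of the low-quality input can be
+ # noised to the starting `timestep` via `scheduler.add_noise` (presumably what the
+ # `init_latents_with_lq` option of `__call__` toggles).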
+ @torch.no_grad()
+ def init_latents(self, latents, generator, timestep):
+ noise = torch.randn(
+ latents.shape, generator=generator, device=self.vae.device, dtype=self.vae.dtype, layout=torch.strided
+ )
+ bsz = latents.shape[0]
+ logger.info(f"Initializing latents at timestep {timestep}")
+ timestep = torch.tensor([timestep] * bsz, device=self.vae.device)
+ # Note that the latents are already scaled by `scheduler.add_noise`.
+ latents = self.scheduler.add_noise(latents, noise, timestep)
+ return latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
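+ # SDXL micro-conditioning: the added time ids are the concatenation of
+ # `original_size + crops_coords_top_left + target_size`; each entry is embedded with
+ # `addition_time_embed_dim` channels and, together with the pooled text embedding, must match
+ # `unet.add_embedding.linear_1.in_features`.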
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.FloatTensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ image: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 30,
+ timesteps: List[int] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 7.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ save_preview_row: bool = False,
+ init_latents_with_lq: bool = True,
+ multistep_restore: bool = False,
+ adastep_restore: bool = False,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ controlnet_conditioning_scale: float = 1.0,
+ control_guidance_start: float = 0.0,
+ control_guidance_end: float = 1.0,
+ preview_start: float = 0.0,
+ preview_end: float = 1.0,
+ original_size: Tuple[int, int] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Tuple[int, int] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ previewer_scheduler: KarrasDiffusionSchedulers = None,
+ reference_latents: Optional[torch.FloatTensor] = None,
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders.
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+ The low-quality input image to restore. It is also used to condition the aggregator that guides the
+ `unet` during generation. If the type is specified as `torch.FloatTensor`, it is passed to the
+ aggregator as is. `PIL.Image.Image` can also be accepted as an image. The dimensions of the output
+ image default to `image`'s dimensions; if height and/or width are passed, `image` is resized
+ accordingly.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co./stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co./stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 30):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co./docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 7.0):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, pooled text embeddings are generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
+ argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+ The outputs of the aggregator are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`.
+ control_guidance_start (`float`, *optional*, defaults to 0.0):
+ The fraction of total steps at which the aggregator guidance starts applying.
+ control_guidance_end (`float`, *optional*, defaults to 1.0):
+ The fraction of total steps at which the aggregator guidance stops applying.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a target image resolution. In most cases it
+ should be the same as `target_size`. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned containing the output images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ aggregator = self.aggregator._orig_mod if is_compiled_module(self.aggregator) else self.aggregator
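+ # If no IP-adapter image is supplied, fall back to using the (low-quality) input image itself as the
+ # image prompt, so the adapter always receives a conditioning image.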
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image] if ip_adapter_image is not None else [image]
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ negative_pooled_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ if not isinstance(image, PIL.Image.Image):
+ batch_size = len(image)
+ else:
+ batch_size = 1
+ prompt = [prompt] * batch_size
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ assert batch_size == len(image) or (isinstance(image, PIL.Image.Image) or len(image) == 1)
+ else:
+ batch_size = prompt_embeds.shape[0]
+ assert batch_size == len(image) or (isinstance(image, PIL.Image.Image) or len(image) == 1)
+
+ device = self._execution_device
+
+ # 3.1 Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 3.2 Encode ip_adapter_image
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 4. Prepare image
+ image = self.prepare_image(
+ image=image,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=aggregator.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ )
+ height, width = image.shape[-2:]
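+ # `prepare_image` may return either pixels or latents: a 4-channel tensor is treated as a VAE latent, so
+ # only non-4-channel inputs are encoded below; for latent inputs, pixel height/width are recovered via the
+ # VAE scale factor.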
+ if image.shape[1] != 4:
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ if needs_upcasting:
+ image = image.float()
+ self.vae.to(dtype=torch.float32)
+ image = self.vae.encode(image).latent_dist.sample()
+ image = image * self.vae.config.scaling_factor
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ image = image.to(dtype=torch.float16)
+ else:
+ height = int(height * self.vae_scale_factor)
+ width = int(width * self.vae_scale_factor)
+
+ # 5. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 6. Prepare latent variables
+ if init_latents_with_lq:
+ latents = self.init_latents(image, generator, timesteps[0])
+ else:
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6.5 Optionally get Guidance Scale Embedding
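+ # `time_cond_proj_dim` is only set on guidance-distilled (LCM-style) UNets, which take the guidance scale
+ # as an extra embedding instead of relying on classifier-free guidance.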
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.1 Create tensor stating which controlnets to keep
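+ # `controlnet_keep[i]` is 1.0 only while step i falls inside [control_guidance_start, control_guidance_end];
+ # `previewing[i]` gates the LCM previewer the same way over [preview_start, preview_end].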
+ controlnet_keep = []
+ previewing = []
+ for i in range(len(timesteps)):
+ keeps = 1.0 - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
+ controlnet_keep.append(keeps)
+ use_preview = 1.0 - float(i / len(timesteps) < preview_start or (i + 1) / len(timesteps) > preview_end)
+ previewing.append(use_preview)
+ if isinstance(controlnet_conditioning_scale, list):
+ assert len(controlnet_conditioning_scale) == len(timesteps), f"{len(controlnet_conditioning_scale)} controlnet scales do not match number of sampling steps {len(timesteps)}"
+ else:
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet_keep)
+
+ # 7.2 Prepare added time ids & embeddings
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+ image = torch.cat([image] * 2, dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ is_unet_compiled = is_compiled_module(self.unet)
+ is_aggregator_compiled = is_compiled_module(self.aggregator)
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
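+ # Running buffers for adaptive restoration: the previous previewer/UNet x0 predictions and a per-sample
+ # restoration factor that rescales the aggregator conditioning at each step.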
+ previewer_mean = torch.zeros_like(latents)
+ unet_mean = torch.zeros_like(latents)
+ preview_factor = torch.ones(
+ (latents.shape[0], *((1,) * (len(latents.shape) - 1))), dtype=latents.dtype, device=latents.device
+ )
+
+ self._num_timesteps = len(timesteps)
+ preview_row = []
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Relevant thread:
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+ if (is_unet_compiled and is_aggregator_compiled) and is_torch_higher_equal_2_1:
+ torch._inductor.cudagraph_mark_step_begin()
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ prev_t = t
+ unet_model_input = latent_model_input
+
+ added_cond_kwargs = {
+ "text_embeds": add_text_embeds,
+ "time_ids": add_time_ids,
+ "image_embeds": image_embeds
+ }
+ aggregator_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ # prepare time_embeds in advance as adapter input
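+ # The UNet time embedding is recomputed here so it can be handed to the attention processors through
+ # `cross_attention_kwargs["temb"]` below.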
+ cross_attention_t_emb = self.unet.get_time_embed(sample=latent_model_input, timestep=t)
+ cross_attention_emb = self.unet.time_embedding(cross_attention_t_emb, timestep_cond)
+ cross_attention_aug_emb = None
+
+ cross_attention_aug_emb = self.unet.get_aug_embed(
+ emb=cross_attention_emb,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs=added_cond_kwargs
+ )
+
+ cross_attention_emb = cross_attention_emb + cross_attention_aug_emb if cross_attention_aug_emb is not None else cross_attention_emb
+
+ if self.unet.time_embed_act is not None:
+ cross_attention_emb = self.unet.time_embed_act(cross_attention_emb)
+
+ current_cross_attention_kwargs = {"temb": cross_attention_emb}
+ if cross_attention_kwargs is not None:
+ for k,v in cross_attention_kwargs.items():
+ current_cross_attention_kwargs[k] = v
+ self._cross_attention_kwargs = current_cross_attention_kwargs
+
+ # adaptive restoration factors
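+ # `adaRes_scale` clamps the adaptive factor into [0, controlnet_conditioning_scale[i]]; multiplying by
+ # `controlnet_keep[i]` zeroes it outside the ControlNet guidance window.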
+ adaRes_scale = preview_factor.to(latent_model_input.dtype).clamp(0.0, controlnet_conditioning_scale[i])
+ cond_scale = adaRes_scale * controlnet_keep[i]
+ cond_scale = torch.cat([cond_scale] * 2) if self.do_classifier_free_guidance else cond_scale
+
+ if (cond_scale > 0.1).sum().item() > 0:
+ if previewing[i] > 0:
+ # preview with LCM
+ self.unet.enable_adapters()
+ preview_noise = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+ preview_latent = previewer_scheduler.step(
+ preview_noise,
+ t.to(dtype=torch.int64),
+ # torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
+ latent_model_input, # scaled latents here for compatibility
+ return_dict=False
+ )[0]
+ self.unet.disable_adapters()
+
+ if self.do_classifier_free_guidance:
+ preview_row.append(preview_latent.chunk(2)[1].to('cpu'))
+ else:
+ preview_row.append(preview_latent.to('cpu'))
+ # Prepare 2nd order step.
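+ # Use the previewer's noise estimate for a provisional (Heun-like) scheduler step so the UNet below is
+ # evaluated at the next timestep rather than the current one.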
+ if multistep_restore and i+1 < len(timesteps):
+ noise_preview = preview_noise.chunk(2)[1] if self.do_classifier_free_guidance else preview_noise
+ first_step = self.scheduler.step(
+ noise_preview, t, latents,
+ **extra_step_kwargs, return_dict=True, step_forward=False
+ )
+ prev_t = timesteps[i + 1]
+ unet_model_input = torch.cat([first_step.prev_sample] * 2) if self.do_classifier_free_guidance else first_step.prev_sample
+ unet_model_input = self.scheduler.scale_model_input(unet_model_input, prev_t, heun_step=True)
+
+ elif reference_latents is not None:
+ preview_latent = torch.cat([reference_latents] * 2) if self.do_classifier_free_guidance else reference_latents
+ else:
+ preview_latent = image
+
+ # Add fresh noise
+ # preview_noise = torch.randn_like(preview_latent)
+ # preview_latent = self.scheduler.add_noise(preview_latent, preview_noise, t)
+
+ preview_latent = preview_latent.to(dtype=next(aggregator.parameters()).dtype)
+
+ # Aggregator inference
+ down_block_res_samples, mid_block_res_sample = aggregator(
+ image,
+ prev_t,
+ encoder_hidden_states=prompt_embeds,
+ controlnet_cond=preview_latent,
+ # conditioning_scale=cond_scale,
+ added_cond_kwargs=aggregator_added_cond_kwargs,
+ return_dict=False,
+ )
+
+ # aggregator features scaling
+ down_block_res_samples = [sample * cond_scale for sample in down_block_res_samples]
+ mid_block_res_sample = mid_block_res_sample * cond_scale
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ unet_model_input,
+ prev_t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ unet_step = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
+ latents = unet_step.prev_sample
+
+ # Update adaRes factors
+ unet_pred_latent = unet_step.pred_original_sample
+
+ # Adaptive restoration.
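+ # The factor is the squared distance between the previewer and UNet x0 predictions, normalized by how much
+ # the previewer prediction moved since the last step; a large ratio strengthens the aggregator conditioning
+ # (up to the clamp above), a small one relaxes it.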
+ if adastep_restore:
+ pred_x0_l2 = ((preview_latent[latents.shape[0]:].float()-unet_pred_latent.float())).pow(2).sum(dim=(1,2,3))
+ previewer_l2 = ((preview_latent[latents.shape[0]:].float()-previewer_mean.float())).pow(2).sum(dim=(1,2,3))
+ # unet_l2 = ((unet_pred_latent.float()-unet_mean.float())).pow(2).sum(dim=(1,2,3)).sqrt()
+ # l2_error = (((preview_latent[latents.shape[0]:]-previewer_mean) - (unet_pred_latent-unet_mean))).pow(2).mean(dim=(1,2,3))
+ # preview_error = torch.nn.functional.cosine_similarity(preview_latent[latents.shape[0]:].reshape(latents.shape[0], -1), unet_pred_latent.reshape(latents.shape[0],-1))
+ previewer_mean = preview_latent[latents.shape[0]:]
+ unet_mean = unet_pred_latent
+ preview_factor = (pred_x0_l2 / previewer_l2).reshape(-1, 1, 1, 1)
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ if save_preview_row:
+ preview_image_row = []
+ # recompute the VAE flags so this branch also works when `output_type == "latent"`,
+ # in which case they were never defined above
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if needs_upcasting:
+ self.upcast_vae()
+ for preview_latents in preview_row:
+ preview_latents = preview_latents.to(device=self.device, dtype=next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(preview_latents.device, preview_latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(preview_latents.device, preview_latents.dtype)
+ )
+ preview_latents = preview_latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ preview_latents = preview_latents / self.vae.config.scaling_factor
+
+ preview_image = self.vae.decode(preview_latents, return_dict=False)[0]
+ preview_image = self.image_processor.postprocess(preview_image, output_type=output_type)
+ preview_image_row.append(preview_image)
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ if save_preview_row:
+ return (image, preview_image_row)
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/pipelines/stage1_sdxl_pipeline.py b/pipelines/stage1_sdxl_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..be429d9e1baa849a6cf57932f665dd6a1756157c
--- /dev/null
+++ b/pipelines/stage1_sdxl_pipeline.py
@@ -0,0 +1,1283 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import (
+ FromSingleFileMixin,
+ IPAdapterMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.attention_processor import (
+ AttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ LoRAAttnProcessor2_0,
+ LoRAXFormersAttnProcessor,
+ XFormersAttnProcessor,
+)
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...schedulers import KarrasDiffusionSchedulers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ is_invisible_watermark_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from .pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+ from .watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import StableDiffusionXLPipeline
+
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
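+# Minimal usage sketch (illustration only; `pipe` stands for any pipeline instance): pass either
+# `num_inference_steps` or a custom `timesteps` schedule, never both, e.g.
+#     timesteps, steps = retrieve_timesteps(pipe.scheduler, num_inference_steps=30, device="cuda")
+#     timesteps, steps = retrieve_timesteps(pipe.scheduler, timesteps=[999, 749, 499, 249], device="cuda")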
+
+
+class StableDiffusionXLPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ FromSingleFileMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ TextualInversionLoaderMixin,
+ IPAdapterMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion XL.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
+ [CLIP](https://huggingface.co./docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co./openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+ [CLIP](https://huggingface.co./docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+ specifically the
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co./laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+ variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co./docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`CLIPTokenizer`):
+ Second Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co./docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+ `stabilityai/stable-diffusion-xl-base-1-0`.
+ add_watermarker (`bool`, *optional*):
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+ watermarker will be used.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+ _optional_components = [
+ "tokenizer",
+ "tokenizer_2",
+ "text_encoder",
+ "text_encoder_2",
+ "image_encoder",
+ "feature_extractor",
+ ]
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ "add_text_embeds",
+ "add_time_ids",
+ "negative_pooled_prompt_embeds",
+ "negative_add_time_ids",
+ ]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ text_encoder_2: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ tokenizer_2: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ force_zeros_for_empty_prompt: bool = True,
+ add_watermarker: Optional[bool] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ unet=unet,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+
+ self.default_sample_size = self.unet.config.sample_size
+
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+ if add_watermarker:
+ self.watermark = StableDiffusionXLWatermarker()
+ else:
+ self.watermark = None
+
+ def encode_prompt(
+ self,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+ text_encoders = (
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are always only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(self, TextualInversionLoaderMixin):
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are always only interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if self.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if self.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if do_classifier_free_guidance:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // self.vae_scale_factor,
+ int(width) // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ def _get_add_time_ids(
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
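+ # SDXL micro-conditioning: six integers (original H/W, crop top-left, target H/W), each Fourier-embedded
+ # with `addition_time_embed_dim` and concatenated with the pooled text embedding inside the UNet.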
+
+ passed_add_embed_dim = (
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ def upcast_vae(self):
+ dtype = self.vae.dtype
+ self.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ self.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ LoRAXFormersAttnProcessor,
+ LoRAAttnProcessor2_0,
+ FusedAttnProcessor2_0,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ self.vae.post_quant_conv.to(dtype)
+ self.vae.decoder.conv_in.to(dtype)
+ self.vae.decoder.mid_block.to(dtype)
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.FloatTensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def guidance_rescale(self):
+ return self._guidance_rescale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def denoising_end(self):
+ return self._denoising_end
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ denoising_end: Optional[float] = None,
+ guidance_scale: float = 5.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ original_size: Optional[Tuple[int, int]] = None,
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
+ target_size: Optional[Tuple[int, int]] = None,
+ negative_original_size: Optional[Tuple[int, int]] = None,
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+ negative_target_size: Optional[Tuple[int, int]] = None,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co./stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ Anything below 512 pixels won't work well for
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co./stabilityai/stable-diffusion-xl-base-1.0)
+ and checkpoints that are not specifically fine-tuned on low resolutions.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ denoising_end (`float`, *optional*):
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+ Output**](https://huggingface.co./docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
+                IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` corresponds to `φ` in equation 16 of
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+ explained in section 2.2 of
+ [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952).
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952).
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+ section 2.2 of [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952).
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+ micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be the
+                same as `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+ [https://huggingface.co./papers/2307.01952](https://huggingface.co./papers/2307.01952). For more
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+ callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
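+            A minimal usage sketch, assuming `pipe` is an already-instantiated copy of this pipeline (model loading
+            is omitted and the prompt is only a placeholder):
+
+            ```py
+            >>> image = pipe(
+            ...     prompt="a photo of an astronaut riding a horse on mars",
+            ...     num_inference_steps=50,
+            ...     guidance_scale=5.0,
+            ... ).images[0]
+            ```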
+
+ Returns:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ # 0. Default height and width to unet
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_size = original_size or (height, width)
+ target_size = target_size or (height, width)
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._guidance_rescale = guidance_rescale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+ self._denoising_end = denoising_end
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=lora_scale,
+ clip_skip=self.clip_skip,
+ )
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Prepare added time ids & embeddings
+ add_text_embeds = pooled_prompt_embeds
+ if self.text_encoder_2 is None:
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+ else:
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+ add_time_ids = self._get_add_time_ids(
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ if negative_original_size is not None and negative_target_size is not None:
+ negative_add_time_ids = self._get_add_time_ids(
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype=prompt_embeds.dtype,
+ text_encoder_projection_dim=text_encoder_projection_dim,
+ )
+ else:
+ negative_add_time_ids = add_time_ids
+
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
+
+ prompt_embeds = prompt_embeds.to(device)
+ add_text_embeds = add_text_embeds.to(device)
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 8. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ # 8.1 Apply denoising_end
+ if (
+ self.denoising_end is not None
+ and isinstance(self.denoising_end, float)
+ and self.denoising_end > 0
+ and self.denoising_end < 1
+ ):
+ discrete_timestep_cutoff = int(
+ round(
+ self.scheduler.config.num_train_timesteps
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+ )
+ )
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+ timesteps = timesteps[:num_inference_steps]
+
+ # 9. Optionally get Guidance Scale Embedding
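+        # `time_cond_proj_dim` is typically only set on guidance-distilled UNets (e.g. LCM-style models), which take
+        # the guidance scale as an embedded condition instead of applying classifier-free guidance at inference time.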
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ self._num_timesteps = len(timesteps)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds, # [B, 77, 2048]
+ timestep_cond=timestep_cond, # None
+ cross_attention_kwargs=self.cross_attention_kwargs, # None
+ added_cond_kwargs=added_cond_kwargs, # {[B, 1280], [B, 6]}
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
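+                    # rescale_noise_cfg rescales the guided prediction toward the standard deviation of the
+                    # text-conditioned branch (blended by `guidance_rescale`), mitigating over-exposure with
+                    # zero-terminal-SNR schedules.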
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+ negative_pooled_prompt_embeds = callback_outputs.pop(
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+ )
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+
+ if needs_upcasting:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+ elif latents.dtype != self.vae.dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ self.vae = self.vae.to(latents.dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+ if has_latents_mean and has_latents_std:
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+ else:
+ latents = latents / self.vae.config.scaling_factor
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if needs_upcasting:
+ self.vae.to(dtype=torch.float16)
+ else:
+ image = latents
+
+ if not output_type == "latent":
+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)
+
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return StableDiffusionXLPipelineOutput(images=image)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3c6afbdc100c993a6a21b22d5e976d8239695adb
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+diffusers==0.28.1
+accelerate==0.25.0
+datasets==2.19.1
+einops==0.8.0
+kornia==0.7.2
+numpy==1.26.4
+opencv-python==4.9.0.80
+peft==0.10.0
+pyrallis==0.3.1
+tokenizers==0.15.2
+torch==2.0.1
+torchvision==0.15.2
+transformers==4.36.2
+gradio==4.44.1
\ No newline at end of file
diff --git a/schedulers/lcm_single_step_scheduler.py b/schedulers/lcm_single_step_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a32affdc29c9dd69a35e6220c853ef0d5e98cb74
--- /dev/null
+++ b/schedulers/lcm_single_step_scheduler.py
@@ -0,0 +1,537 @@
+# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class LCMSingleStepSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's `step` function output.
+
+ Args:
+        denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
+            `denoised` can be used to preview progress or for guidance.
+ """
+
+ denoised: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
+            Choose from `cosine` or `exp`.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
+def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+
+ Args:
+ betas (`torch.FloatTensor`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = torch.cat([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+class LCMSingleStepScheduler(SchedulerMixin, ConfigMixin):
+ """
+    `LCMSingleStepScheduler` implements the single-step denoising used by latent consistency models (LCMs): given the
+    model prediction at a timestep, it directly returns the consistency-model estimate of the clean sample.
+
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
+ attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
+ accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
+ functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
+
+ Args:
+ num_train_timesteps (`int`, defaults to 1000):
+ The number of diffusion steps to train the model.
+ beta_start (`float`, defaults to 0.0001):
+ The starting `beta` value of inference.
+ beta_end (`float`, defaults to 0.02):
+ The final `beta` value.
+ beta_schedule (`str`, defaults to `"linear"`):
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, *optional*):
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
+ original_inference_steps (`int`, *optional*, defaults to 50):
+ The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
+ will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
+ clip_sample (`bool`, defaults to `True`):
+ Clip the predicted sample for numerical stability.
+ clip_sample_range (`float`, defaults to 1.0):
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, defaults to `True`):
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the alpha value at step 0.
+ steps_offset (`int`, defaults to 0):
+ An offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
+ Diffusion.
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
+            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
+ Video](https://imagen.research.google/video/paper.pdf) paper).
+ thresholding (`bool`, defaults to `False`):
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
+ as Stable Diffusion.
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
+ sample_max_value (`float`, defaults to 1.0):
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ timestep_spacing (`str`, defaults to `"leading"`):
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://huggingface.co./papers/2305.08891) for more information.
+ timestep_scaling (`float`, defaults to 10.0):
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
+ error at the default of `10.0` is already pretty small).
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
+ """
+
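+    # A minimal single-step usage sketch (`pipe`, `unet`, `noisy_latents`, `prompt_embeds` and the tensor timestep
+    # `t` are placeholders supplied by the caller):
+    #   scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
+    #   model_output = unet(noisy_latents, t, encoder_hidden_states=prompt_embeds).sample
+    #   denoised = scheduler.step(model_output, t, noisy_latents, return_dict=False)[0]
+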
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.00085,
+ beta_end: float = 0.012,
+ beta_schedule: str = "scaled_linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ original_inference_steps: int = 50,
+ clip_sample: bool = False,
+ clip_sample_range: float = 1.0,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ sample_max_value: float = 1.0,
+ timestep_spacing: str = "leading",
+ timestep_scaling: float = 10.0,
+ rescale_betas_zero_snr: bool = False,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ # Rescale for zero SNR
+ if rescale_betas_zero_snr:
+ self.betas = rescale_zero_terminal_snr(self.betas)
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ self._step_index = None
+
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+ def _init_step_index(self, timestep):
+ if isinstance(timestep, torch.Tensor):
+ timestep = timestep.to(self.timesteps.device)
+
+ index_candidates = (self.timesteps == timestep).nonzero()
+
+ # The sigma index that is taken for the **very** first `step`
+ # is always the second index (or the last index if there is only 1)
+ # This way we can ensure we don't accidentally skip a sigma in
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
+ if len(index_candidates) > 1:
+ step_index = index_candidates[1]
+ else:
+ step_index = index_candidates[0]
+
+ self._step_index = step_index.item()
+
+ @property
+ def step_index(self):
+ return self._step_index
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The input sample.
+ timestep (`int`, *optional*):
+ The current timestep in the diffusion chain.
+ Returns:
+ `torch.FloatTensor`:
+ A scaled input sample.
+ """
+ return sample
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, *remaining_dims = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
+ sample = sample.to(dtype)
+
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int = None,
+ device: Union[str, torch.device] = None,
+ original_inference_steps: Optional[int] = None,
+        strength: float = 1.0,
+ timesteps: Optional[list] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
+
+ Args:
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model.
+ device (`str` or `torch.device`, *optional*):
+                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+ original_inference_steps (`int`, *optional*):
+ The original number of inference steps, which will be used to generate a linearly-spaced timestep
+ schedule (which is different from the standard `diffusers` implementation). We will then take
+ `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
+                our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
+            strength (`float`, *optional*, defaults to 1.0):
+                Used to truncate the schedule for image-to-image style denoising: only the first
+                `int(original_inference_steps * strength)` original timesteps are kept before sub-sampling.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use directly (in descending order) instead of computing a schedule from
+                `num_inference_steps`. Cannot be passed together with `num_inference_steps`.
+ """
+
+ if num_inference_steps is not None and timesteps is not None:
+            raise ValueError("Can only pass one of `num_inference_steps` or `timesteps`.")
+
+ if timesteps is not None:
+ for i in range(1, len(timesteps)):
+ if timesteps[i] >= timesteps[i - 1]:
+ raise ValueError("`custom_timesteps` must be in descending order.")
+
+ if timesteps[0] >= self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`timesteps` must start before `self.config.num_train_timesteps`:"
+                    f" {self.config.num_train_timesteps}."
+                )
+
+ timesteps = np.array(timesteps, dtype=np.int64)
+ else:
+ if num_inference_steps > self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.num_train_timesteps`:"
+                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                    f" maximal {self.config.num_train_timesteps} timesteps."
+                )
+
+ self.num_inference_steps = num_inference_steps
+ original_steps = (
+ original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
+ )
+
+ if original_steps > self.config.num_train_timesteps:
+                raise ValueError(
+                    f"`original_steps`: {original_steps} cannot be larger than `self.config.num_train_timesteps`:"
+                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+                    f" maximal {self.config.num_train_timesteps} timesteps."
+                )
+
+ if num_inference_steps > original_steps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
+ f" {original_steps} because the final timestep schedule will be a subset of the"
+ f" `original_inference_steps`-sized initial timestep schedule."
+ )
+
+ # LCM Timesteps Setting
+ # Currently, only linear spacing is supported.
+ c = self.config.num_train_timesteps // original_steps
+ # LCM Training Steps Schedule
+ lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
+ # LCM Inference Steps Schedule
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
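+            # e.g. with num_train_timesteps=1000, original_steps=50, strength=1.0, num_inference_steps=4:
+            # c = 20, lcm_origin_timesteps = [19, 39, ..., 999], skipping_step = 12, timesteps = [999, 759, 519, 279]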
+
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device=device, dtype=torch.long)
+
+ self._step_index = None
+
+ def get_scalings_for_boundary_condition_discrete(self, timestep):
+ self.sigma_data = 0.5 # Default: 0.5
+ scaled_timestep = timestep * self.config.timestep_scaling
+
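+        # Consistency-model boundary conditions: as the timestep goes to 0, c_skip -> 1 and c_out -> 0, so
+        # c_skip * sample + c_out * prediction reduces to the identity mapping at t = 0.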
+ c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+ def append_dims(self, x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+ def extract_into_tensor(self, a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: torch.Tensor,
+ sample: torch.FloatTensor,
+ generator: Optional[torch.Generator] = None,
+ return_dict: bool = True,
+ ) -> Union[LCMSingleStepSchedulerOutput, Tuple]:
+ """
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`):
+ The direct output from learned diffusion model.
+            timestep (`torch.Tensor`):
+ The current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ A current instance of a sample created by the diffusion process.
+ generator (`torch.Generator`, *optional*):
+ A random number generator.
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`LCMSingleStepSchedulerOutput`] or `tuple`.
+        Returns:
+            [`LCMSingleStepSchedulerOutput`] or `tuple`:
+                If return_dict is `True`, [`LCMSingleStepSchedulerOutput`] is returned, otherwise a
+                tuple is returned where the first element is the denoised sample tensor.
+ """
+ # 0. make sure everything is on the same device
+ alphas_cumprod = self.alphas_cumprod.to(sample.device)
+
+ # 1. compute alphas, betas
+ if timestep.ndim == 0:
+ timestep = timestep.unsqueeze(0)
+ alpha_prod_t = self.extract_into_tensor(alphas_cumprod, timestep, sample.shape)
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 2. Get scalings for boundary conditions
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
+ c_skip, c_out = [self.append_dims(x, sample.ndim) for x in [c_skip, c_out]]
+
+ # 3. Compute the predicted original sample x_0 based on the model parameterization
+ if self.config.prediction_type == "epsilon": # noise-prediction
+ predicted_original_sample = (sample - torch.sqrt(beta_prod_t) * model_output) / torch.sqrt(alpha_prod_t)
+ elif self.config.prediction_type == "sample": # x-prediction
+ predicted_original_sample = model_output
+ elif self.config.prediction_type == "v_prediction": # v-prediction
+ predicted_original_sample = torch.sqrt(alpha_prod_t) * sample - torch.sqrt(beta_prod_t) * model_output
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
+                " `v_prediction` for `LCMSingleStepScheduler`."
+ )
+
+ # 4. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ predicted_original_sample = self._threshold_sample(predicted_original_sample)
+ elif self.config.clip_sample:
+ predicted_original_sample = predicted_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 5. Denoise model output using boundary conditions
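+        # This is the consistency parameterization f(x_t, t) = c_skip * x_t + c_out * x0_hat; no fresh noise is
+        # added here, so a single call already yields a clean estimate of the sample.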
+ denoised = c_out * predicted_original_sample + c_skip * sample
+
+ if not return_dict:
+ return (denoised, )
+
+ return LCMSingleStepSchedulerOutput(denoised=denoised)
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
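+        # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise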
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
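+        # v-prediction target: v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample (the clean latents)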
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/train_previewer_lora.py b/train_previewer_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1cf004f5df686efb3d0a0ff71abfd67652c4de7
--- /dev/null
+++ b/train_previewer_lora.py
@@ -0,0 +1,1712 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The LCM team and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import copy
+import functools
+import gc
+import logging
+import pyrallis
+import math
+import os
+import random
+import shutil
+from contextlib import nullcontext
+from pathlib import Path
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from PIL import Image
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import ProjectConfiguration, set_seed
+from datasets import load_dataset
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from collections import namedtuple
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import (
+ AutoTokenizer,
+ PretrainedConfig,
+ CLIPImageProcessor, CLIPVisionModelWithProjection,
+ AutoImageProcessor, AutoModel
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ LCMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
+from diffusers.utils import (
+ check_min_version,
+ convert_state_dict_to_diffusers,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+from basicsr.utils.degradation_pipeline import RealESRGANDegradation
+from utils.train_utils import (
+ seperate_ip_params_from_unet,
+ import_model_class_from_model_name_or_path,
+ tensor_to_pil,
+ get_train_dataset, prepare_train_dataset, collate_fn,
+ encode_prompt, importance_sampling_fn, extract_into_tensor
+)
+from data.data_config import DataConfig
+from losses.loss_config import LossesConfig
+from losses.losses import *
+
+from module.ip_adapter.resampler import Resampler
+from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
+
+logger = get_logger(__name__)
+
+
+def prepare_latents(lq, vae, scheduler, generator, timestep):
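+    # Resize/center-crop the LQ images to the training resolution (read from the globally visible `args`), encode
+    # them with the VAE and scale by its scaling factor, then diffuse the latents to `timestep` with Gaussian noise.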
+ transform = transforms.Compose([
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution),
+ transforms.ToTensor(),
+ ])
+ lq_pt = [transform(lq_pil.convert("RGB")) for lq_pil in lq]
+ img_pt = torch.stack(lq_pt).to(vae.device, dtype=vae.dtype)
+ img_pt = img_pt * 2.0 - 1.0
+ with torch.no_grad():
+ latents = vae.encode(img_pt).latent_dist.sample()
+ latents = latents * vae.config.scaling_factor
+ noise = torch.randn(latents.shape, generator=generator, device=vae.device, dtype=vae.dtype, layout=torch.strided).to(vae.device)
+ bsz = latents.shape[0]
+ print(f"init latent at {timestep}")
+ timestep = torch.tensor([timestep]*bsz, device=vae.device, dtype=torch.int64)
+ latents = scheduler.add_noise(latents, noise, timestep)
+ return latents
+
+
+def log_validation(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ scheduler, image_encoder, image_processor,
+ args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
+ logger.info("Running validation... ")
+
+ image_logs = []
+
+ lq = [Image.open(lq_example) for lq_example in args.validation_image]
+
+ pipe = StableDiffusionXLPipeline(
+ vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ unet, scheduler, image_encoder, image_processor,
+ ).to(accelerator.device)
+
+ timesteps = [args.num_train_timesteps - 1]
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ latents = prepare_latents(lq, vae, scheduler, generator, timesteps[-1])
+ image = pipe(
+ prompt=[""]*len(lq),
+ ip_adapter_image=[lq],
+ num_inference_steps=1,
+ timesteps=timesteps,
+ generator=generator,
+ guidance_scale=1.0,
+ height=args.resolution,
+ width=args.resolution,
+ latents=latents,
+ ).images
+
+ if log_local:
+ # for i, img in enumerate(tensor_to_pil(lq_img)):
+ # img.save(f"./lq_{i}.png")
+ # for i, img in enumerate(tensor_to_pil(gt_img)):
+ # img.save(f"./gt_{i}.png")
+ for i, img in enumerate(image):
+ img.save(f"./lq_IPA_{i}.png")
+ return
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ images = [np.asarray(pil_img) for pil_img in image]
+ images = np.stack(images, axis=0)
+ if lq_img is not None and gt_img is not None:
+ input_lq = lq_img.detach().cpu()
+ input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
+ input_gt = gt_img.detach().cpu()
+ input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
+ tracker.writer.add_images("lq", input_lq, step, dataformats="NCHW")
+ tracker.writer.add_images("gt", input_gt, step, dataformats="NCHW")
+ tracker.writer.add_images("rec", images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ raise NotImplementedError("Wandb logging not implemented for validation.")
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({tracker_key: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+class DDIMSolver:
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
+ # DDIM sampling parameters
+ step_ratio = timesteps // ddim_timesteps
+
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
+ self.ddim_alpha_cumprods_prev = np.asarray(
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
+ )
+ # convert to torch tensors
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
+
+ def to(self, device):
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
+ return self
+
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
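+        # Deterministic DDIM update (eta = 0): re-noise the predicted x_0 to the previous timestep,
+        # x_prev = sqrt(alpha_bar_prev) * pred_x0 + sqrt(1 - alpha_bar_prev) * pred_noise.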
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
+ return x_prev
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
+ return x[(...,) + (None,) * dims_to_append]
+
+
+# From LCMScheduler.get_scalings_for_boundary_condition_discrete
+def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
+ scaled_timestep = timestep_scaling * timestep
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
+ return c_skip, c_out
+
+
+# Compare LCMScheduler.step, Step 4
+def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ if prediction_type == "epsilon":
+ pred_x_0 = (sample - sigmas * model_output) / alphas
+ elif prediction_type == "sample":
+ pred_x_0 = model_output
+ elif prediction_type == "v_prediction":
+ pred_x_0 = alphas * sample - sigmas * model_output
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_x_0
+
+
+# Based on step 4 in DDIMScheduler.step
+def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
+ if prediction_type == "epsilon":
+ pred_epsilon = model_output
+ elif prediction_type == "sample":
+ pred_epsilon = (sample - alphas * model_output) / sigmas
+ elif prediction_type == "v_prediction":
+ pred_epsilon = alphas * model_output + sigmas * sample
+ else:
+ raise ValueError(
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
+ f" are supported."
+ )
+
+ return pred_epsilon
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ # ----------Model Checkpoint Loading Arguments----------
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--teacher_revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_lcm_lora_path",
+ type=str,
+ default=None,
+ help="Path to LCM lora or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--feature_extractor_path",
+ type=str,
+ default=None,
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_adapter_model_path",
+ type=str,
+ default=None,
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--adapter_tokens",
+ type=int,
+ default=64,
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
+ )
+ parser.add_argument(
+ "--use_clip_encoder",
+ action="store_true",
+        help="Whether or not to use CLIP as the image encoder; if not set, a DINO encoder is used.",
+ )
+ parser.add_argument(
+ "--image_encoder_hidden_feature",
+ action="store_true",
+ help="Whether or not to use the penultimate hidden states as image embeddings.",
+ )
+ # ----------Training Arguments----------
+ # ----General Training Arguments----
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lcm-xl-distilled",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ # ----Logging----
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ # ----Checkpointing----
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=4000,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=5,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--save_only_adapter",
+ action="store_true",
+ help="Only save extra adapter to save space.",
+ )
+ # ----Image Processing----
+ parser.add_argument(
+ "--data_config_path",
+ type=str,
+ default=None,
+        help=("Path to the training data configuration file."),
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co./docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--text_drop_rate",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--image_drop_rate",
+ type=float,
+ default=0,
+ help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
+ )
+ parser.add_argument(
+ "--cond_drop_rate",
+ type=float,
+ default=0,
+ help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
+ )
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--interpolation_type",
+ type=str,
+ default="bilinear",
+ help=(
+ "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--encode_batch_size",
+ type=int,
+ default=8,
+ help="Batch size to use for VAE encoding of the images for efficient processing.",
+ )
+ # ----Dataloader----
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ # ----Batch Size and Training Steps----
+ parser.add_argument(
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=100)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ # ----Learning Rate----
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-6,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ # ----Optimizer (Adam)----
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ # ----Diffusion Training Arguments----
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
+ parser.add_argument(
+ "--w_min",
+ type=float,
+ default=3.0,
+ required=False,
+ help=(
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--w_max",
+ type=float,
+ default=15.0,
+ required=False,
+ help=(
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
+ " compared to the original paper."
+ ),
+ )
+ parser.add_argument(
+ "--num_train_timesteps",
+ type=int,
+ default=1000,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--num_ddim_timesteps",
+ type=int,
+ default=50,
+ help="The number of timesteps to use for DDIM sampling.",
+ )
+ parser.add_argument(
+ "--losses_config_path",
+ type=str,
+ default='config_files/losses.yaml',
+ required=True,
+ help=("A yaml file containing losses to use and their weights."),
+ )
+ parser.add_argument(
+ "--loss_type",
+ type=str,
+ default="l2",
+ choices=["l2", "huber"],
+ help="The type of loss to use for the LCD loss.",
+ )
+ parser.add_argument(
+ "--huber_c",
+ type=float,
+ default=0.001,
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=64,
+ help="The rank of the LoRA projection matrix.",
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=64,
+ help=(
+ "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+ " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+ ),
+ )
+ parser.add_argument(
+ "--lora_dropout",
+ type=float,
+ default=0.0,
+ help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+ )
+ parser.add_argument(
+ "--lora_target_modules",
+ type=str,
+ default=None,
+ help=(
+ "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+ " be used. By default, LoRA will be applied to all conv and linear layers."
+ ),
+ )
+ parser.add_argument(
+ "--vae_encode_batch_size",
+ type=int,
+ default=8,
+ required=False,
+ help=(
+ "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
+ " Encoding or decoding the whole batch at once may run into OOM issues."
+ ),
+ )
+ parser.add_argument(
+ "--timestep_scaling_factor",
+ type=float,
+ default=10.0,
+ help=(
+ "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
+ " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
+ " suffice."
+ ),
+ )
+ # ----Mixed Precision----
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ # ----Training Optimizations----
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ # ----Distributed Training----
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+ # ----------Validation Arguments----------
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=3000,
+ help="Run validation every X steps.",
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--sanity_check",
+ action="store_true",
+ help=(
+ "sanity check"
+ ),
+ )
+ # ----------Huggingface Hub Arguments-----------
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ # ----------Accelerate Arguments----------
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="trian",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co./docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ args = parser.parse_args()
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ return args
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+    # Handle the repository creation.
+    if accelerator.is_main_process:
+        if args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+        if args.push_to_hub:
+            # Assumes `create_repo` is imported from huggingface_hub; `repo_id` is used by `upload_folder` at the end of training.
+            repo_id = create_repo(repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token).repo_id
+
+ # 1. Create the noise scheduler and the desired noise schedule.
+ noise_scheduler = DDPMScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.teacher_revision
+ )
+ noise_scheduler.config.num_train_timesteps = args.num_train_timesteps
+ lcm_scheduler = LCMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
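+    # For a variance-preserving schedule, a noisy latent at timestep t is z_t = alpha_t * x_0 + sigma_t * eps, so
+    # alpha_schedule[t] = sqrt(alphas_cumprod[t]) and sigma_schedule[t] = sqrt(1 - alphas_cumprod[t]) are the
+    # coefficients used below to recover x_0 / eps predictions from the model output.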
+ # Initialize the DDIM ODE solver for distillation.
+ solver = DDIMSolver(
+ noise_scheduler.alphas_cumprod.numpy(),
+ timesteps=noise_scheduler.config.num_train_timesteps,
+ ddim_timesteps=args.num_ddim_timesteps,
+ )
+
+ # 2. Load tokenizers from SDXL checkpoint.
+ tokenizer_one = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
+ )
+ tokenizer_two = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
+ )
+
+ # 3. Load text encoders from SDXL checkpoint.
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.teacher_revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.teacher_revision, subfolder="text_encoder_2"
+ )
+
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.teacher_revision
+ )
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.teacher_revision
+ )
+
+ if args.use_clip_encoder:
+ image_processor = CLIPImageProcessor()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
+ else:
+ image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
+ image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
+
+ # 4. Load VAE from SDXL checkpoint (or more stable VAE)
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.teacher_revision,
+ )
+
+    # 5. Create online student U-Net.
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.teacher_revision
+ )
+
+    # Resampler used as the image projection model for the IP-Adapter.
+ image_proj_model = Resampler(
+ dim=1280,
+ depth=4,
+ dim_head=64,
+ heads=20,
+ num_queries=args.adapter_tokens,
+ embedding_dim=image_encoder.config.hidden_size,
+ output_dim=unet.config.cross_attention_dim,
+ ff_mult=4
+ )
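+    # The Resampler cross-attends a small set of learned query tokens (`adapter_tokens` of them) to the image-encoder
+    # features, mapping them from `image_encoder.config.hidden_size` to the UNet cross-attention width so the
+    # low-quality image can be injected through the IP-Adapter attention layers.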
+
+ # Load the same adapter in both unet.
+ init_adapter_in_unet(
+ unet,
+ image_proj_model,
+ os.path.join(args.pretrained_adapter_model_path, 'adapter_ckpt.pt'),
+ adapter_tokens=args.adapter_tokens,
+ )
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ if unwrap_model(unet).dtype != torch.float32:
+ raise ValueError(
+ f"Controlnet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}"
+ )
+
+ if args.pretrained_lcm_lora_path is not None:
+ lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(args.pretrained_lcm_lora_path)
+ unet_state_dict = {
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+ }
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        lora_state_dict = dict()
+        for k, v in unet_state_dict.items():
+            # IP-Adapter attention keys live under the attention processor module, so remap them.
+            if "ip" in k:
+                k = k.replace("attn2", "attn2.processor")
+            lora_state_dict[k] = v
+ if alpha_dict:
+ args.lora_alpha = next(iter(alpha_dict.values()))
+ else:
+ args.lora_alpha = 1
+    # 6. Add LoRA to the student U-Net, only the LoRA projection matrices will be updated by the optimizer.
+ if args.lora_target_modules is not None:
+ lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
+ else:
+ lora_target_modules = [
+ "to_q",
+ "to_kv",
+ "0.to_out",
+ "attn1.to_k",
+ "attn1.to_v",
+ "to_k_ip",
+ "to_v_ip",
+ "ln_k_ip.linear",
+ "ln_v_ip.linear",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+ ]
+ lora_config = LoraConfig(
+ r=args.lora_rank,
+ target_modules=lora_target_modules,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ )
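+    # With PEFT, each adapted weight becomes W + (lora_alpha / r) * B @ A, where A and B are the low-rank factors.
+    # Keeping lora_alpha equal to lora_rank makes the scaling factor 1, matching the help text of `--lora_alpha`.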
+
+ unet.add_adapter(lora_config)
+ if args.pretrained_lcm_lora_path is not None:
+ incompatible_keys = set_peft_model_state_dict(unet, lora_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+    # 7. Freeze unet, vae, text encoders, and image encoder.
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ unet.requires_grad_(False)
+
+    # 8. Handle saving and loading of checkpoints
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if args.save_only_adapter:
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ if isinstance(model, type(unwrap_model(unet))): # save adapter only
+ unet_ = unwrap_model(model)
+ # also save the checkpoints in native `diffusers` format so that it can be easily
+ # be independently loaded via `load_lora_weights()`.
+ state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
+ StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict, safe_serialization=False)
+
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(unet))):
+ unet_ = unwrap_model(model)
+ lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
+ unet_state_dict = {
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+ }
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+                    lora_state_dict = dict()
+                    for k, v in unet_state_dict.items():
+                        # IP-Adapter attention keys live under the attention processor module, so remap them.
+                        if "ip" in k:
+                            k = k.replace("attn2", "attn2.processor")
+                        lora_state_dict[k] = v
+ incompatible_keys = set_peft_model_state_dict(unet_, lora_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+    # 9. Enable optimizations
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co./docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ vae.enable_gradient_checkpointing()
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+    # 10. Optimizer creation
+ lora_params, non_lora_params = seperate_lora_params_from_unet(unet)
+ params_to_optimize = lora_params
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+    # 11. Dataset creation and data processing
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ datasets = []
+ datasets_name = []
+ datasets_weights = []
+ deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
+ if args.data_config_path is not None:
+ data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
+ for single_dataset in data_config.datasets:
+ datasets_weights.append(single_dataset.dataset_weight)
+ datasets_name.append(single_dataset.dataset_folder)
+ dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
+ image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
+ image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
+ datasets.append(image_dataset)
+ # TODO: Validation dataset
+ if data_config.val_dataset is not None:
+ val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator)
+ logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
+
+ # Mix training datasets.
+ sampler_train = None
+ if len(datasets) == 1:
+ train_dataset = datasets[0]
+ else:
+ # Weighted each dataset
+ train_dataset = torch.utils.data.ConcatDataset(datasets)
+ dataset_weights = []
+ for single_dataset, single_weight in zip(datasets, datasets_weights):
+ dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
+ sampler_train = torch.utils.data.WeightedRandomSampler(
+ weights=dataset_weights,
+ num_samples=len(dataset_weights)
+ )
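+        # Each sample's weight is (len(concat) / len(its dataset)) * dataset_weight, so the sampling probability of a
+        # dataset is proportional to its configured weight regardless of its size. E.g. two datasets of 1,000 and
+        # 9,000 images with weight 1.0 each get per-sample weights 10.0 and ~1.11, so each contributes about half the draws.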
+
+ # DataLoaders creation:
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ sampler=sampler_train,
+ shuffle=True if sampler_train is None else False,
+ collate_fn=collate_fn,
+ batch_size=args.train_batch_size,
+ num_workers=args.dataloader_num_workers,
+ )
+
+    # 12. Embeddings for the UNet.
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
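+    # SDXL micro-conditioning: `add_time_ids` packs (original_h, original_w, crop_top, crop_left, target_h, target_w)
+    # into a 6-dim vector per sample, while the pooled text embedding is passed separately as `text_embeds`.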
+ def compute_embeddings(prompt_batch, original_sizes, crop_coords, text_encoders, tokenizers, is_train=True):
+ def compute_time_ids(original_size, crops_coords_top_left):
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(prompt_batch, text_encoders, tokenizers, is_train)
+ add_text_embeds = pooled_prompt_embeds
+
+ add_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(original_sizes, crop_coords)])
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ add_text_embeds = add_text_embeds.to(accelerator.device)
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
+
+ text_encoders = [text_encoder_one, text_encoder_two]
+ tokenizers = [tokenizer_one, tokenizer_two]
+
+ compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers)
+
+ # Move pixels into latents.
+ @torch.no_grad()
+ def convert_to_latent(pixels):
+ model_input = vae.encode(pixels).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+ return model_input
+
+    # 13. LR Scheduler creation
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(unet, dtype=torch.float32)
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ )
+
+    # 14. Prepare for training
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+    # 15. Handle mixed precision and device placement
+    # For mixed precision training we cast all non-trainable weights to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
+ # The VAE is in float32 to avoid NaN losses.
+ if args.pretrained_vae_model_name_or_path is None:
+ vae.to(accelerator.device, dtype=torch.float32)
+ else:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ for p in non_lora_params:
+ p.data = p.data.to(dtype=weight_dtype)
+ for p in lora_params:
+ p.requires_grad_(True)
+ unet.to(accelerator.device)
+
+ # Also move the alpha and sigma noise schedules to accelerator.device.
+ alpha_schedule = alpha_schedule.to(accelerator.device)
+ sigma_schedule = sigma_schedule.to(accelerator.device)
+ solver = solver.to(accelerator.device)
+
+ # Instantiate Loss.
+ losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
+ lcm_losses = list()
+ for loss_config in losses_configs.lcm_losses:
+ logger.info(f"Loading lcm loss: {loss_config.name}")
+ loss = namedtuple("loss", ["loss", "weight"])
+ loss_class = eval(loss_config.name)
+ lcm_losses.append(loss(loss_class(
+ visualize_every_k=loss_config.visualize_every_k,
+ dtype=weight_dtype,
+ accelerator=accelerator,
+ dino_model=image_encoder,
+ dino_preprocess=image_processor,
+ huber_c=args.huber_c,
+ **loss_config.init_params), weight=loss_config.weight))
+
+ # Final check.
+ for n, p in unet.named_parameters():
+ if p.requires_grad:
+ assert "lora" in n, n
+ assert p.dtype == torch.float32, n
+ else:
+ assert "lora" not in n, f"{n}"
+ assert p.dtype == weight_dtype, n
+ if args.sanity_check:
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+
+ # Check input data
+ batch = next(iter(train_dataloader))
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
+ out_images = log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two,
+ lcm_scheduler, image_encoder, image_processor,
+ args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, is_final_validation=False, log_local=True)
+ exit()
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+    # 16. Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ unet.train()
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(unet):
+ total_loss = torch.tensor(0.0)
+ bsz = batch["images"].shape[0]
+
+ # Drop conditions.
+ rand_tensor = torch.rand(bsz)
+ drop_image_idx = rand_tensor < args.image_drop_rate
+ drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
+ drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
+ drop_image_idx = drop_image_idx | drop_both_idx
+ drop_text_idx = drop_text_idx | drop_both_idx
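+                # The drop masks partition [0, 1): e.g. with image/text/cond drop rates of 0.05 each, a sample drops
+                # only its image condition with p=0.05, only its text with p=0.05, both with p=0.05, and keeps both
+                # with p=0.85, so the three rates never overlap for a single draw of `rand_tensor`.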
+
+ with torch.no_grad():
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
+ lq_pt = image_processor(
+ images=lq_img*0.5+0.5,
+ do_rescale=False, return_tensors="pt"
+ ).pixel_values
+ image_embeds = prepare_training_image_embeds(
+ image_encoder, image_processor,
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
+ device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
+ idx_to_replace=drop_image_idx
+ )
+ uncond_image_embeds = prepare_training_image_embeds(
+ image_encoder, image_processor,
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
+ device=accelerator.device, drop_rate=1.0, output_hidden_state=args.image_encoder_hidden_feature,
+ idx_to_replace=torch.ones_like(drop_image_idx)
+ )
+ # 1. Load and process the image and text conditioning
+ text, orig_size, crop_coords = (
+ batch["text"],
+ batch["original_sizes"],
+ batch["crop_top_lefts"],
+ )
+
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
+ uncond_encoded_text = compute_embeddings_fn([""]*len(text), orig_size, crop_coords)
+
+ # encode pixel values with batch size of at most args.vae_encode_batch_size
+ gt_img = gt_img.to(dtype=vae.dtype)
+ latents = []
+ for i in range(0, gt_img.shape[0], args.vae_encode_batch_size):
+ latents.append(vae.encode(gt_img[i : i + args.vae_encode_batch_size]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+ # latents = convert_to_latent(gt_img)
+
+ latents = latents * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ latents = latents.to(weight_dtype)
+
+ # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
+ # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
+ bsz = latents.shape[0]
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
+ start_timesteps = solver.ddim_timesteps[index]
+ timesteps = start_timesteps - topk
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
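+                # Worked example (assuming the usual LCM DDIMSolver construction): with num_train_timesteps=1000 and
+                # num_ddim_timesteps=50, consecutive solver timesteps are topk=20 apart, so `timesteps` is simply
+                # `start_timesteps` shifted one solver step earlier, clamped at 0 for the earliest step.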
+
+ # 3. Get boundary scalings for start_timesteps and (end) timesteps.
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(
+ start_timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
+ c_skip, c_out = scalings_for_boundary_conditions(
+ timesteps, timestep_scaling=args.timestep_scaling_factor
+ )
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
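+                # In the standard LCM boundary parameterization (assuming diffusers' `scalings_for_boundary_conditions`
+                # with sigma_data=0.5): c_skip(t) = sigma_data^2 / ((s*t)^2 + sigma_data^2) and
+                # c_out(t) = s*t / sqrt((s*t)^2 + sigma_data^2), with s = timestep_scaling_factor, so c_skip -> 1 and
+                # c_out -> 0 as t -> 0.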
+
+ # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
+ # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
+ noise = torch.randn_like(latents)
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
+
+ # 5. Sample a random guidance scale w from U[w_min, w_max]
+ # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
+ w = w.reshape(bsz, 1, 1, 1)
+ w = w.to(device=latents.device, dtype=latents.dtype)
+
+ # 6. Prepare prompt embeds and unet_added_conditions
+ prompt_embeds = encoded_text.pop("prompt_embeds")
+ encoded_text["image_embeds"] = image_embeds
+ uncond_prompt_embeds = uncond_encoded_text.pop("prompt_embeds")
+ uncond_encoded_text["image_embeds"] = image_embeds
+
+ # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
+ noise_pred = unet(
+ noisy_model_input,
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds,
+ added_cond_kwargs=uncond_encoded_text,
+ ).sample
+ pred_x_0 = get_predicted_original_sample(
+ noise_pred,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
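+                # Consistency-model parameterization: f_theta(z_t, t) = c_skip(t) * z_t + c_out(t) * x_hat_0(z_t, t),
+                # which enforces the boundary condition f_theta(x, 0) ~= x regardless of the network output.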
+
+ # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
+ # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
+ # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
+ # solver timestep.
+
+ # With the adapters disabled, the `unet` is the regular teacher model.
+ accelerator.unwrap_model(unet).disable_adapters()
+ with torch.no_grad():
+
+ # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
+ teacher_added_cond = dict()
+ for k,v in encoded_text.items():
+ if isinstance(v, torch.Tensor):
+ teacher_added_cond[k] = v.to(weight_dtype)
+ else:
+ teacher_image_embeds = []
+ for img_emb in v:
+ teacher_image_embeds.append(img_emb.to(weight_dtype))
+ teacher_added_cond[k] = teacher_image_embeds
+ cond_teacher_output = unet(
+ noisy_model_input,
+ start_timesteps,
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs=teacher_added_cond,
+ ).sample
+ cond_pred_x0 = get_predicted_original_sample(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ cond_pred_noise = get_predicted_noise(
+ cond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
+ teacher_added_uncond = dict()
+ uncond_encoded_text["image_embeds"] = uncond_image_embeds
+ for k,v in uncond_encoded_text.items():
+ if isinstance(v, torch.Tensor):
+ teacher_added_uncond[k] = v.to(weight_dtype)
+ else:
+ teacher_uncond_image_embeds = []
+ for img_emb in v:
+ teacher_uncond_image_embeds.append(img_emb.to(weight_dtype))
+ teacher_added_uncond[k] = teacher_uncond_image_embeds
+ uncond_teacher_output = unet(
+ noisy_model_input,
+ start_timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ added_cond_kwargs=teacher_added_uncond,
+ ).sample
+ uncond_pred_x0 = get_predicted_original_sample(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ uncond_pred_noise = get_predicted_noise(
+ uncond_teacher_output,
+ start_timesteps,
+ noisy_model_input,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+
+ # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
+ # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
+ pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
+ # 4. Run one step of the ODE solver to estimate the next point x_prev on the
+ # augmented PF-ODE trajectory (solving backward in time)
+ # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(weight_dtype)
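+                    # Assuming the standard deterministic DDIM update (eta=0) inside DDIMSolver:
+                    # x_prev = sqrt(alpha_cumprod_prev) * pred_x0 + sqrt(1 - alpha_cumprod_prev) * pred_noise,
+                    # i.e. one solver step from t_{n+k} back to t_n along the CFG-augmented PF-ODE.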
+
+ # re-enable unet adapters to turn the `unet` into a student unet.
+ accelerator.unwrap_model(unet).enable_adapters()
+
+ # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
+ # Note that we do not use a separate target network for LCM-LoRA distillation.
+ with torch.no_grad():
+ uncond_encoded_text["image_embeds"] = image_embeds
+ target_added_cond = dict()
+ for k,v in uncond_encoded_text.items():
+ if isinstance(v, torch.Tensor):
+ target_added_cond[k] = v.to(weight_dtype)
+ else:
+ target_image_embeds = []
+ for img_emb in v:
+ target_image_embeds.append(img_emb.to(weight_dtype))
+ target_added_cond[k] = target_image_embeds
+ target_noise_pred = unet(
+ x_prev,
+ timesteps,
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
+ added_cond_kwargs=target_added_cond,
+ ).sample
+ pred_x_0 = get_predicted_original_sample(
+ target_noise_pred,
+ timesteps,
+ x_prev,
+ noise_scheduler.config.prediction_type,
+ alpha_schedule,
+ sigma_schedule,
+ )
+ target = c_skip * x_prev + c_out * pred_x_0
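+                    # The loss below pushes f_theta(z_{t_{n+k}}, t_{n+k}) toward this target f_theta(x_prev, t_n),
+                    # i.e. the student is trained to be self-consistent along the teacher's PF-ODE trajectory.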
+
+ # 10. Calculate loss
+ lcm_loss_arguments = {
+ "target": target.float(),
+ "predict": model_pred.float(),
+ }
+ loss_dict = dict()
+ # total_loss = total_loss + torch.mean(
+ # torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
+ # )
+ # loss_dict["L2Loss"] = total_loss.item()
+ for loss_config in lcm_losses:
+ if loss_config.loss.__class__.__name__=="DINOLoss":
+ with torch.no_grad():
+ pixel_target = []
+ latent_target = target.to(dtype=vae.dtype)
+ for i in range(0, latent_target.shape[0], args.vae_encode_batch_size):
+ pixel_target.append(
+ vae.decode(
+ latent_target[i : i + args.vae_encode_batch_size] / vae.config.scaling_factor,
+ return_dict=False
+ )[0]
+ )
+ pixel_target = torch.cat(pixel_target, dim=0)
+ pixel_pred = []
+ latent_pred = model_pred.to(dtype=vae.dtype)
+ for i in range(0, latent_pred.shape[0], args.vae_encode_batch_size):
+ pixel_pred.append(
+ vae.decode(
+ latent_pred[i : i + args.vae_encode_batch_size] / vae.config.scaling_factor,
+ return_dict=False
+ )[0]
+ )
+ pixel_pred = torch.cat(pixel_pred, dim=0)
+ dino_loss_arguments = {
+ "target": pixel_target,
+ "predict": pixel_pred,
+ }
+ non_weighted_loss = loss_config.loss(**dino_loss_arguments, accelerator=accelerator)
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
+ total_loss = total_loss + non_weighted_loss * loss_config.weight
+ else:
+ non_weighted_loss = loss_config.loss(**lcm_loss_arguments, accelerator=accelerator)
+ total_loss = total_loss + non_weighted_loss * loss_config.weight
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
+
+ # 11. Backpropagate on the online student model (`unet`) (only LoRA)
+ accelerator.backward(total_loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ out_images = log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two,
+ lcm_scheduler, image_encoder, image_processor,
+ args, accelerator, weight_dtype, global_step, lq_img, gt_img, is_final_validation=False, log_local=False)
+
+ logs = dict()
+ # logs.update({"loss": loss.detach().item()})
+ logs.update(loss_dict)
+ logs.update({"lr": lr_scheduler.get_last_lr()[0]})
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ # Create the pipeline using using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ unet = accelerator.unwrap_model(unet)
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
+ StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
+
+ if args.push_to_hub:
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+        # Final inference with the trained LoRA before releasing the U-Net.
+        if args.validation_image is not None:
+            log_validation(unet, vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two,
+                        lcm_scheduler, image_encoder, image_processor,
+                        args, accelerator, weight_dtype, step=global_step, is_final_validation=True, log_local=True)
+
+        del unet
+        torch.cuda.empty_cache()
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
\ No newline at end of file
diff --git a/train_previewer_lora.sh b/train_previewer_lora.sh
new file mode 100644
index 0000000000000000000000000000000000000000..29db5164dbb77f8025c67bd0f0ca379c793d2212
--- /dev/null
+++ b/train_previewer_lora.sh
@@ -0,0 +1,24 @@
+# After DCP training, distill the Previewer with DCP in `train_previewer_lora.py`.
+# Fill in the <...> placeholders with your own GPU count and local paths before running.
+accelerate launch --num_processes <num_of_gpus> train_previewer_lora.py \
+    --output_dir <output_dir> \
+    --train_data_dir <train_data_dir> \
+    --logging_dir <logging_dir> \
+    --pretrained_model_name_or_path <sdxl_model_path> \
+    --feature_extractor_path <image_encoder_path> \
+    --pretrained_adapter_model_path <pretrained_dcp_adapter_path> \
+ --losses_config_path config_files/losses.yaml \
+ --data_config_path config_files/IR_dataset.yaml \
+ --save_only_adapter \
+ --gradient_checkpointing \
+ --num_train_timesteps 1000 \
+ --num_ddim_timesteps 50 \
+ --lora_alpha 1 \
+ --mixed_precision fp16 \
+ --train_batch_size 32 \
+ --vae_encode_batch_size 16 \
+ --gradient_accumulation_steps 1 \
+ --learning_rate 1e-4 \
+ --lr_warmup_steps 1000 \
+ --lr_scheduler cosine \
+ --lr_num_cycles 1 \
+ --resume_from_checkpoint latest
\ No newline at end of file
diff --git a/train_stage1_adapter.py b/train_stage1_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec2ee3326ec2c6b85c0e164cc3c1282b76d91143
--- /dev/null
+++ b/train_stage1_adapter.py
@@ -0,0 +1,1259 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import contextlib
+import time
+import gc
+import logging
+import math
+import os
+import random
+import jsonlines
+import functools
+import shutil
+import pyrallis
+import itertools
+from pathlib import Path
+from collections import namedtuple, OrderedDict
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from packaging import version
+from PIL import Image
+from data.data_config import DataConfig
+from basicsr.utils.degradation_pipeline import RealESRGANDegradation
+from losses.loss_config import LossesConfig
+from losses.losses import *
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import (
+ AutoTokenizer,
+ PretrainedConfig,
+ CLIPImageProcessor, CLIPVisionModelWithProjection,
+ AutoImageProcessor, AutoModel)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ AutoencoderTiny,
+ DDPMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
+from utils.train_utils import (
+ seperate_ip_params_from_unet,
+ import_model_class_from_model_name_or_path,
+ tensor_to_pil,
+ get_train_dataset, prepare_train_dataset, collate_fn,
+ encode_prompt, importance_sampling_fn, extract_into_tensor
+)
+from module.ip_adapter.resampler import Resampler
+from module.ip_adapter.attention_processor import init_attn_proc
+from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
+
+
+if is_wandb_available():
+ import wandb
+
+
+logger = get_logger(__name__)
+
+
+def log_validation(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ scheduler, image_encoder, image_processor, deg_pipeline,
+ args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
+ logger.info("Running validation... ")
+
+ image_logs = []
+
+ lq = [Image.open(lq_example) for lq_example in args.validation_image]
+
+ pipe = StableDiffusionXLPipeline(
+ vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ unet, scheduler, image_encoder, image_processor,
+ ).to(accelerator.device)
+
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ image = pipe(
+ prompt=[""]*len(lq),
+ ip_adapter_image=[lq],
+ num_inference_steps=20,
+ generator=generator,
+ guidance_scale=5.0,
+ height=args.resolution,
+ width=args.resolution,
+ ).images
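+    # Stage-1 validation: the low-quality images are the only condition (prompts are empty), so these samples show how
+    # much restoration the adapter alone provides through the IP-Adapter image path.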
+
+ if log_local:
+ for i, img in enumerate(tensor_to_pil(lq_img)):
+ img.save(f"./lq_{i}.png")
+ for i, img in enumerate(tensor_to_pil(gt_img)):
+ img.save(f"./gt_{i}.png")
+ for i, img in enumerate(image):
+ img.save(f"./lq_IPA_{i}.png")
+ return
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ images = [np.asarray(pil_img) for pil_img in image]
+ images = np.stack(images, axis=0)
+ if lq_img is not None and gt_img is not None:
+ input_lq = lq_img.detach().cpu()
+ input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
+ input_gt = gt_img.detach().cpu()
+ input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
+ tracker.writer.add_images("lq", input_lq[0], step, dataformats="CHW")
+ tracker.writer.add_images("gt", input_gt[0], step, dataformats="CHW")
+ tracker.writer.add_images("rec", images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ raise NotImplementedError("Wandb logging not implemented for validation.")
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({tracker_key: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="InstantIR stage-1 training.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--feature_extractor_path",
+ type=str,
+ default=None,
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_adapter_model_path",
+ type=str,
+ default=None,
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--adapter_tokens",
+ type=int,
+ default=64,
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
+ )
+ parser.add_argument(
+ "--use_clip_encoder",
+ action="store_true",
+ help="Whether or not to use DINO as image encoder, else CLIP encoder.",
+ )
+ parser.add_argument(
+ "--image_encoder_hidden_feature",
+ action="store_true",
+ help="Whether or not to use the penultimate hidden states as image embeddings.",
+ )
+ parser.add_argument(
+ "--losses_config_path",
+ type=str,
+ required=True,
+        default='config_files/losses.yaml',
+        help=("A yaml file containing losses to use and their weights."),
+ )
+ parser.add_argument(
+ "--data_config_path",
+ type=str,
+ default='config_files/IR_dataset.yaml',
+ help=("A folder containing the training data. "),
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="stage1_model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=2000,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co./docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=5,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--save_only_adapter",
+ action="store_true",
+ help="Only save extra adapter to save space.",
+ )
+ parser.add_argument(
+ "--importance_sampling",
+ action="store_true",
+ help="Whether or not to use importance sampling.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that ๐ค Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co./docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--text_drop_rate",
+ type=float,
+ default=0.05,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--image_drop_rate",
+ type=float,
+ default=0.05,
+ help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
+ )
+ parser.add_argument(
+ "--cond_drop_rate",
+ type=float,
+ default=0.05,
+ help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
+ )
+ parser.add_argument(
+ "--sanity_check",
+ action="store_true",
+ help=(
+ "sanity check"
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=3000,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default="instantir_stage1",
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co./docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ # if args.dataset_name is None and args.train_data_dir is None and args.data_config_path is None:
+ # raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.text_drop_rate < 0 or args.text_drop_rate > 1:
+ raise ValueError("`--text_drop_rate` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ # kwargs_handlers=[kwargs],
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation.
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+ # Importance sampling.
+ list_of_candidates = np.arange(noise_scheduler.config.num_train_timesteps, dtype='float64')
+ prob_dist = importance_sampling_fn(list_of_candidates, noise_scheduler.config.num_train_timesteps, 0.5)
+ importance_ratio = prob_dist / prob_dist.sum() * noise_scheduler.config.num_train_timesteps
+ importance_ratio = torch.from_numpy(importance_ratio.copy()).float()
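+    # `importance_ratio` has mean 1.0 over timesteps (uniform sampling would give exactly 1.0
+    # everywhere); when `--importance_sampling` is set it is gathered per sample via
+    # `extract_into_tensor(importance_ratio, timesteps, noise.shape)` and handed to the
+    # diffusion losses as `weights` in the training loop below.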
+
+ # Load the tokenizers
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_2 = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # Text encoder and image encoder.
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+ text_encoder = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_2 = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ if args.use_clip_encoder:
+ image_processor = CLIPImageProcessor()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
+ else:
+ image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
+ image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
+
+ # VAE.
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+
+ # UNet.
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ revision=args.revision,
+ variant=args.variant
+ )
+
+ pipe = StableDiffusionXLPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ unet=unet,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ vae=vae,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ variant=args.variant
+ )
+
+    # Resampler used as the image projection model of the IP-Adapter.
+ image_proj_model = Resampler(
+ dim=1280,
+ depth=4,
+ dim_head=64,
+ heads=20,
+ num_queries=args.adapter_tokens,
+ embedding_dim=image_encoder.config.hidden_size,
+ output_dim=unet.config.cross_attention_dim,
+ ff_mult=4
+ )
+
+ init_adapter_in_unet(
+ unet,
+ image_proj_model,
+ os.path.join(args.pretrained_adapter_model_path, 'adapter_ckpt.pt'),
+ adapter_tokens=args.adapter_tokens,
+ )
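+    # `init_adapter_in_unet` is expected to install the IP-Adapter attention processors in the
+    # UNet, register the resampler as its image projection, and load the pretrained weights from
+    # `adapter_ckpt.pt`; the save/load hooks below rely on exactly these two components.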
+
+ # Initialize training state.
+ vae.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ text_encoder_2.requires_grad_(False)
+ unet.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if args.save_only_adapter:
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
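+            # Custom hook: instead of serializing the whole UNet, only the IP-Adapter pieces are
+            # saved as a plain dict with two keys, "image_proj" (resampler weights) and
+            # "ip_adapter" (attention-processor weights), the same layout `load_model_hook` expects.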
+ if accelerator.is_main_process:
+ for model in models:
+ if isinstance(model, type(unwrap_model(unet))): # save adapter only
+ adapter_state_dict = OrderedDict()
+ adapter_state_dict["image_proj"] = model.encoder_hid_proj.image_projection_layers[0].state_dict()
+ adapter_state_dict["ip_adapter"] = torch.nn.ModuleList(model.attn_processors.values()).state_dict()
+ torch.save(adapter_state_dict, os.path.join(output_dir, "adapter_ckpt.pt"))
+
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
+ adapter_state_dict = torch.load(os.path.join(input_dir, "adapter_ckpt.pt"), map_location="cpu")
+ if list(adapter_state_dict.keys()) != ["image_proj", "ip_adapter"]:
+ from module.ip_adapter.utils import revise_state_dict
+ adapter_state_dict = revise_state_dict(adapter_state_dict)
+ model.encoder_hid_proj.image_projection_layers[0].load_state_dict(adapter_state_dict["image_proj"], strict=True)
+ missing, unexpected = torch.nn.ModuleList(model.attn_processors.values()).load_state_dict(adapter_state_dict["ip_adapter"], strict=False)
+ if len(unexpected) > 0:
+ raise ValueError(f"Unexpected keys: {unexpected}")
+ if len(missing) > 0:
+ for mk in missing:
+ if "ln" not in mk:
+ raise ValueError(f"Missing keys: {missing}")
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co./docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ vae.enable_gradient_checkpointing()
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs.
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation.
+ ip_params, non_ip_params = seperate_ip_params_from_unet(unet)
+ params_to_optimize = ip_params
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Instantiate Loss.
+ losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
+ diffusion_losses = list()
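+    # Loss names in the yaml are resolved with `eval()`, so each `name` must match a loss class
+    # available in this script's namespace (e.g. one exported by `losses.losses`). A minimal,
+    # purely illustrative sketch of `losses.yaml` consistent with the fields read below:
+    #   diffusion_losses:
+    #     - name: L2Loss          # hypothetical class name
+    #       weight: 1.0
+    #       visualize_every_k: -1
+    #       init_params: {}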
+ for loss_config in losses_configs.diffusion_losses:
+ logger.info(f"Loading diffusion loss: {loss_config.name}")
+ loss = namedtuple("loss", ["loss", "weight"])
+ loss_class = eval(loss_config.name)
+ diffusion_losses.append(loss(loss_class(visualize_every_k=loss_config.visualize_every_k,
+ dtype=weight_dtype,
+ accelerator=accelerator,
+ **loss_config.init_params), weight=loss_config.weight))
+
+ # SDXL additional condition that will be added to time embedding.
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
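+        # The six SDXL "add_time_ids" are (original_h, original_w, crop_top, crop_left,
+        # target_h, target_w); e.g. a 512x512 training crop taken from the top-left corner of a
+        # 768x1024 source image yields [768, 1024, 0, 0, 512, 512].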
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ # Text prompt embeddings.
+ @torch.no_grad()
+ def compute_embeddings(batch, text_encoders, tokenizers, drop_idx=None, is_train=True):
+ prompt_batch = batch[args.caption_column]
+ if drop_idx is not None:
+ for i in range(len(prompt_batch)):
+ prompt_batch[i] = "" if drop_idx[i] else prompt_batch[i]
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, is_train
+ )
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+ sdxl_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
+
+ return prompt_embeds, sdxl_added_cond_kwargs
+
+ # Move pixels into latents.
+ @torch.no_grad()
+ def convert_to_latent(pixels):
+ model_input = vae.encode(pixels).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
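+        # Scaling by `vae.config.scaling_factor` (0.13025 for the SDXL VAE) brings the latents
+        # to roughly unit variance, the range the diffusion UNet is trained on.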
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+ return model_input
+
+    # Datasets and other data modules.
+ deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ text_encoders=[text_encoder, text_encoder_2],
+ tokenizers=[tokenizer, tokenizer_2],
+ is_train=True,
+ )
+
+ datasets = []
+ datasets_name = []
+ datasets_weights = []
+ if args.data_config_path is not None:
+ data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
+ for single_dataset in data_config.datasets:
+ datasets_weights.append(single_dataset.dataset_weight)
+ datasets_name.append(single_dataset.dataset_folder)
+ dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
+ image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
+ image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
+ datasets.append(image_dataset)
+        # TODO: Validation dataset
+        if data_config.val_dataset is not None:
+            # NOTE: assumes the `val_dataset` entry mirrors the training dataset config.
+            val_dataset_dir = os.path.join(args.train_data_dir, data_config.val_dataset.dataset_folder)
+            val_dataset = get_train_dataset(val_dataset_dir, val_dataset_dir, args, accelerator)
+ logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
+
+ # Mix training datasets.
+ sampler_train = None
+ if len(datasets) == 1:
+ train_dataset = datasets[0]
+ else:
+        # Weight each dataset by its configured sampling weight.
+ train_dataset = torch.utils.data.ConcatDataset(datasets)
+ dataset_weights = []
+ for single_dataset, single_weight in zip(datasets, datasets_weights):
+ dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
+ sampler_train = torch.utils.data.WeightedRandomSampler(
+ weights=dataset_weights,
+ num_samples=len(dataset_weights)
+ )
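+        # Each sample's weight is (len(concat) / len(its_dataset)) * dataset_weight, so a
+        # dataset's expected share of draws is proportional to its configured weight regardless
+        # of its size: e.g. weights 1 and 2 give a 1:2 sampling ratio even if the first dataset
+        # is ten times larger.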
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ sampler=sampler_train,
+        shuffle=sampler_train is None,
+ collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers
+ )
+
+ # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initialize automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
+ if args.pretrained_vae_model_name_or_path is None:
+ # The VAE is fp32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+ else:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ importance_ratio = importance_ratio.to(accelerator.device)
+ for non_ip_param in non_ip_params:
+ non_ip_param.data = non_ip_param.data.to(dtype=weight_dtype)
+ for ip_param in ip_params:
+ ip_param.requires_grad_(True)
+ unet.to(accelerator.device)
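+    # At this point only the IP-Adapter parameters remain trainable (and in fp32); the frozen
+    # UNet weights have been cast to `weight_dtype`, which the dtype assertions below verify.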
+
+ # Final check.
+ for n, p in unet.named_parameters():
+        if p.requires_grad:
+            assert p.dtype == torch.float32, n
+        else:
+            assert p.dtype == weight_dtype, n
+ if args.sanity_check:
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+
+ # Check input data
+ batch = next(iter(train_dataloader))
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
+ images_log = log_validation(
+ unwrap_model(unet), vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ noise_scheduler, image_encoder, image_processor, deg_pipeline,
+ args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, is_final_validation=False, log_local=True
+ )
+ exit()
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Optimization steps per epoch = {num_update_steps_per_epoch}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ trainable_models = [unet]
+
+    checkpoint_models = []
+
+ image_logs = None
+ tic = time.time()
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ toc = time.time()
+ io_time = toc - tic
+ tic = toc
+ for model in trainable_models + checkpoint_models:
+ model.train()
+ with accelerator.accumulate(*trainable_models):
+ loss = torch.tensor(0.0)
+
+ # Drop conditions.
+ rand_tensor = torch.rand(batch["images"].shape[0])
+ drop_image_idx = rand_tensor < args.image_drop_rate
+ drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
+ drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
+ drop_image_idx = drop_image_idx | drop_both_idx
+ drop_text_idx = drop_text_idx | drop_both_idx
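+                # One uniform draw per sample partitions [0, 1): the first `image_drop_rate` of
+                # the mass drops only the image condition, the next `text_drop_rate` drops only
+                # the text, and the next `cond_drop_rate` drops both; with the 0.05 defaults each
+                # case covers 5% of samples and the remaining 85% keep all conditions.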
+
+ # Get LQ embeddings
+ with torch.no_grad():
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
+ lq_pt = image_processor(
+ images=lq_img*0.5+0.5,
+ do_rescale=False, return_tensors="pt"
+ ).pixel_values
+ image_embeds = prepare_training_image_embeds(
+ image_encoder, image_processor,
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
+ device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
+ idx_to_replace=drop_image_idx
+ )
+
+ # Process text inputs.
+ prompt_embeds_input, added_conditions = compute_embeddings_fn(batch, drop_idx=drop_text_idx)
+ added_conditions["image_embeds"] = image_embeds
+
+ # Move inputs to latent space.
+ gt_img = gt_img.to(dtype=vae.dtype)
+ model_input = convert_to_latent(gt_img)
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+
+ # Sample noise that we'll add to the latents.
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image.
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+ loss_weights = extract_into_tensor(importance_ratio, timesteps, noise.shape) if args.importance_sampling else None
+
+ toc = time.time()
+ prepare_time = toc - tic
+ tic = time.time()
+
+ model_pred = unet(
+ noisy_model_input, timesteps,
+ encoder_hidden_states=prompt_embeds_input,
+ added_cond_kwargs=added_conditions,
+ return_dict=False
+ )[0]
+
+ diffusion_loss_arguments = {
+ "target": noise,
+ "predict": model_pred,
+ "prompt_embeddings_input": prompt_embeds_input,
+ "timesteps": timesteps,
+ "weights": loss_weights,
+ }
+
+ loss_dict = dict()
+ for loss_config in diffusion_losses:
+ non_weighted_loss = loss_config.loss(**diffusion_loss_arguments, accelerator=accelerator)
+ loss = loss + non_weighted_loss * loss_config.weight
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ toc = time.time()
+ forward_time = toc - tic
+ tic = toc
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ image_logs = log_validation(unwrap_model(unet), vae,
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ noise_scheduler, image_encoder, image_processor, deg_pipeline,
+ args, accelerator, weight_dtype, global_step, lq_img, gt_img, is_final_validation=False)
+
+ logs = {}
+ logs.update(loss_dict)
+ logs.update({
+ "lr": lr_scheduler.get_last_lr()[0],
+ "io_time": io_time,
+ "prepare_time": prepare_time,
+ "forward_time": forward_time
+ })
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+ tic = time.time()
+
+ if global_step >= args.max_train_steps:
+ break
+
+    # Create the pipeline using the trained modules and save it.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ accelerator.save_state(os.path.join(args.output_dir, "last"), safe_serialization=False)
+        # Run a final round of validation with the trained adapter.
+ image_logs = None
+ if args.validation_image is not None:
+ image_logs = log_validation(
+ unwrap_model(unet), vae,
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ noise_scheduler, image_encoder, image_processor, deg_pipeline,
+ args, accelerator, weight_dtype, global_step,
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
\ No newline at end of file
diff --git a/train_stage1_adapter.sh b/train_stage1_adapter.sh
new file mode 100644
index 0000000000000000000000000000000000000000..71fee17856de283ec47183a61bfb6d5b4ae9ec72
--- /dev/null
+++ b/train_stage1_adapter.sh
@@ -0,0 +1,17 @@
+# Stage 1: training lq adapter
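+# NOTE: the empty flags below are placeholders for local paths, and `--num_processes`
+# expects a GPU count (e.g. `--num_processes 8`) before the script name.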
+accelerate launch --num_processes train_stage1_adapter.py \
+ --output_dir \
+ --train_data_dir \
+ --logging_dir \
+ --pretrained_model_name_or_path \
+ --feature_extractor_path \
+ --save_only_adapter \
+ --gradient_checkpointing \
+ --mixed_precision fp16 \
+ --train_batch_size 96 \
+ --gradient_accumulation_steps 1 \
+ --learning_rate 1e-4 \
+ --lr_warmup_steps 1000 \
+ --lr_scheduler cosine \
+ --lr_num_cycles 1 \
+ --resume_from_checkpoint latest
\ No newline at end of file
diff --git a/train_stage2_aggregator.py b/train_stage2_aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..31c617d8ac295e5a7bca637cc029a5af6c85be27
--- /dev/null
+++ b/train_stage2_aggregator.py
@@ -0,0 +1,1698 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import time
+import gc
+import logging
+import math
+import copy
+import random
+import yaml
+import functools
+import shutil
+import pyrallis
+from pathlib import Path
+from collections import namedtuple, OrderedDict
+
+import accelerate
+import numpy as np
+import torch
+from safetensors import safe_open
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
+from huggingface_hub import create_repo, upload_folder
+from packaging import version
+from PIL import Image
+from data.data_config import DataConfig
+from basicsr.utils.degradation_pipeline import RealESRGANDegradation
+from losses.loss_config import LossesConfig
+from losses.losses import *
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+from transformers import (
+ AutoTokenizer,
+ PretrainedConfig,
+ CLIPImageProcessor, CLIPVisionModelWithProjection,
+ AutoImageProcessor, AutoModel
+)
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ StableDiffusionXLPipeline,
+ UNet2DConditionModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.utils import (
+ check_min_version,
+ convert_unet_state_dict_to_peft,
+ is_wandb_available,
+)
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+from module.aggregator import Aggregator
+from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
+from module.ip_adapter.ip_adapter import MultiIPAdapterImageProjection
+from module.ip_adapter.resampler import Resampler
+from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
+from module.ip_adapter.attention_processor import init_attn_proc
+from utils.train_utils import (
+ seperate_ip_params_from_unet,
+ import_model_class_from_model_name_or_path,
+ tensor_to_pil,
+ get_train_dataset, prepare_train_dataset, collate_fn,
+ encode_prompt, importance_sampling_fn, extract_into_tensor
+)
+from pipelines.sdxl_instantir import InstantIRPipeline
+
+
+if is_wandb_available():
+ import wandb
+
+
+logger = get_logger(__name__)
+
+
+def log_validation(unet, aggregator, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
+ args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
+ logger.info("Running validation... ")
+
+ image_logs = []
+
+ # validation_batch = batchify_pil(args.validation_image, args.validation_prompt, deg_pipeline, image_processor)
+ lq = [Image.open(lq_example).convert("RGB") for lq_example in args.validation_image]
+
+ pipe = InstantIRPipeline(
+ vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ unet, scheduler, aggregator, feature_extractor=image_processor, image_encoder=image_encoder,
+ ).to(accelerator.device)
+
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+ if lq_img is not None and gt_img is not None:
+ lq_img = lq_img[:len(args.validation_image)]
+ lq_pt = image_processor(
+ images=lq_img*0.5+0.5,
+ do_rescale=False, return_tensors="pt"
+ ).pixel_values
+ image = pipe(
+ prompt=[""]*len(lq_img),
+ image=lq_img,
+ ip_adapter_image=lq_pt,
+ num_inference_steps=20,
+ generator=generator,
+ controlnet_conditioning_scale=1.0,
+ negative_prompt=[""]*len(lq),
+ guidance_scale=5.0,
+ height=args.resolution,
+ width=args.resolution,
+ lcm_scheduler=lcm_scheduler,
+ ).images
+ else:
+ image = pipe(
+ prompt=[""]*len(lq),
+ image=lq,
+ ip_adapter_image=lq,
+ num_inference_steps=20,
+ generator=generator,
+ controlnet_conditioning_scale=1.0,
+ negative_prompt=[""]*len(lq),
+ guidance_scale=5.0,
+ height=args.resolution,
+ width=args.resolution,
+ lcm_scheduler=lcm_scheduler,
+ ).images
+
+ if log_local:
+ for i, rec_image in enumerate(image):
+ rec_image.save(f"./instantid_{i}.png")
+ return
+
+ tracker_key = "test" if is_final_validation else "validation"
+ for tracker in accelerator.trackers:
+ if tracker.name == "tensorboard":
+ images = [np.asarray(pil_img) for pil_img in image]
+ images = np.stack(images, axis=0)
+ if lq_img is not None and gt_img is not None:
+ input_lq = lq_img.cpu()
+ input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
+ input_gt = gt_img.cpu()
+ input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
+ tracker.writer.add_images("lq", input_lq, step, dataformats="NCHW")
+ tracker.writer.add_images("gt", input_gt, step, dataformats="NCHW")
+ tracker.writer.add_images("rec", images, step, dataformats="NHWC")
+ elif tracker.name == "wandb":
+ raise NotImplementedError("Wandb logging not implemented for validation.")
+ formatted_images = []
+
+ for log in image_logs:
+ images = log["images"]
+ validation_prompt = log["validation_prompt"]
+ validation_image = log["validation_image"]
+
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
+
+ for image in images:
+ image = wandb.Image(image, caption=validation_prompt)
+ formatted_images.append(image)
+
+ tracker.log({tracker_key: formatted_images})
+ else:
+ logger.warning(f"image logging not implemented for {tracker.name}")
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ return image_logs
+
+
+def remove_attn2(model):
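+    # Recursively strips the text cross-attention (`attn2`) and its paired `norm2` from every
+    # transformer block under the down/mid/up blocks, skipping resnets; presumably the stage-2
+    # aggregator is conditioned on image features rather than text, so these layers are unused.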
+ def recursive_find_module(name, module):
+ if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return
+ elif "resnets" in name: return
+ if hasattr(module, "attn2"):
+ setattr(module, "attn2", None)
+ setattr(module, "norm2", None)
+ return
+ for sub_name, sub_module in module.named_children():
+ recursive_find_module(f"{name}.{sub_name}", sub_module)
+
+ for name, module in model.named_children():
+ recursive_find_module(name, module)
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a IP-Adapter training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--controlnet_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to an pretrained controlnet model like tile-controlnet.",
+ )
+ parser.add_argument(
+ "--use_lcm",
+ action="store_true",
+ help="Whether or not to use lcm unet.",
+ )
+ parser.add_argument(
+ "--pretrained_lcm_lora_path",
+ type=str,
+ default=None,
+ help="Path to LCM lora or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--lora_rank",
+ type=int,
+ default=64,
+ help="The rank of the LoRA projection matrix.",
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=64,
+ help=(
+ "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
+ " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
+ ),
+ )
+ parser.add_argument(
+ "--lora_dropout",
+ type=float,
+ default=0.0,
+ help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
+ )
+ parser.add_argument(
+ "--lora_target_modules",
+ type=str,
+ default=None,
+ help=(
+ "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
+ " be used. By default, LoRA will be applied to all conv and linear layers."
+ ),
+ )
+ parser.add_argument(
+ "--feature_extractor_path",
+ type=str,
+ default=None,
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_adapter_model_path",
+ type=str,
+ default=None,
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--adapter_tokens",
+ type=int,
+ default=64,
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
+ )
+ parser.add_argument(
+ "--aggregator_adapter",
+ action="store_true",
+ help="Whether or not to add adapter on aggregator.",
+ )
+ parser.add_argument(
+ "--optimize_adapter",
+ action="store_true",
+ help="Whether or not to optimize IP-Adapter.",
+ )
+ parser.add_argument(
+ "--image_encoder_hidden_feature",
+ action="store_true",
+ help="Whether or not to use the penultimate hidden states as image embeddings.",
+ )
+ parser.add_argument(
+ "--losses_config_path",
+ type=str,
+ required=True,
+ help=("A yaml file containing losses to use and their weights."),
+ )
+ parser.add_argument(
+ "--data_config_path",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--tokenizer_name",
+ type=str,
+ default=None,
+ help="Pretrained tokenizer name or path if not the same as model_name",
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="stage1_model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=3000,
+ help=(
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
+ "See https://huggingface.co./docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
+ "instructions."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=5,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--previous_ckpt",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be initialized from a previous checkpoint."
+ ),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--save_only_adapter",
+ action="store_true",
+ help="Only save extra adapter to save space.",
+ )
+ parser.add_argument(
+ "--cache_prompt_embeds",
+ action="store_true",
+ help="Whether or not to cache prompt embeds to save memory.",
+ )
+ parser.add_argument(
+ "--importance_sampling",
+ action="store_true",
+ help="Whether or not to use importance sampling.",
+ )
+ parser.add_argument(
+ "--CFG_scale",
+ type=float,
+ default=1.0,
+ help="CFG for previewer.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+ )
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+ parser.add_argument(
+ "--set_grads_to_none",
+ action="store_true",
+ help=(
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
+ " behaviors, so disable this argument if it causes any problems. More info:"
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+ ),
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that ๐ค Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--train_data_dir",
+ type=str,
+ default=None,
+ help=(
+ "A folder containing the training data. Folder contents must follow the structure described in"
+ " https://huggingface.co./docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+ ),
+ )
+ parser.add_argument(
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
+ )
+ parser.add_argument(
+ "--conditioning_image_column",
+ type=str,
+ default="conditioning_image",
+ help="The column of the dataset containing the controlnet conditioning image.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default="text",
+ help="The column of the dataset containing a caption or a list of captions.",
+ )
+ parser.add_argument(
+ "--max_train_samples",
+ type=int,
+ default=None,
+ help=(
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ ),
+ )
+ parser.add_argument(
+ "--text_drop_rate",
+ type=float,
+ default=0,
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
+ )
+ parser.add_argument(
+ "--image_drop_rate",
+ type=float,
+ default=0,
+ help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
+ )
+ parser.add_argument(
+ "--cond_drop_rate",
+ type=float,
+ default=0,
+ help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
+ )
+ parser.add_argument(
+ "--use_ema_adapter",
+ action="store_true",
+ help=(
+ "use ema ip-adapter for LCM preview"
+ ),
+ )
+ parser.add_argument(
+ "--sanity_check",
+ action="store_true",
+ help=(
+ "sanity check"
+ ),
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
+ ),
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ nargs="+",
+ help=(
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
+ " `--validation_image` that will be used with all `--validation_prompt`s."
+ ),
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
+ )
+ parser.add_argument(
+ "--validation_steps",
+ type=int,
+ default=4000,
+ help=(
+ "Run validation every X steps. Validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
+ " and logging the images."
+ ),
+ )
+ parser.add_argument(
+ "--tracker_project_name",
+ type=str,
+ default='train',
+ help=(
+ "The `project_name` argument passed to Accelerator.init_trackers for"
+ " more information see https://huggingface.co./docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+ ),
+ )
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if not args.sanity_check and args.dataset_name is None and args.train_data_dir is None and args.data_config_path is None:
+ raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
+
+ if args.dataset_name is not None and args.train_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
+
+ if args.text_drop_rate < 0 or args.text_drop_rate > 1:
+ raise ValueError("`--text_drop_rate` must be in the range [0, 1].")
+
+ if args.validation_prompt is not None and args.validation_image is None:
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
+
+ if args.validation_prompt is None and args.validation_image is not None:
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
+
+ if (
+ args.validation_image is not None
+ and args.validation_prompt is not None
+ and len(args.validation_image) != 1
+ and len(args.validation_prompt) != 1
+ and len(args.validation_image) != len(args.validation_prompt)
+ ):
+ raise ValueError(
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
+ )
+
+ if args.resolution % 8 != 0:
+ raise ValueError(
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+ )
+
+ return args
+
+
+def update_ema_model(ema_model, model, ema_beta):
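+    # In-place EMA update: ema <- ema_beta * ema + (1 - ema_beta) * param
+    # (param.lerp(ema, ema_beta) interpolates from the current weights toward the running average).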
+ for ema_param, param in zip(ema_model.parameters(), model.parameters()):
+ ema_param.copy_(param.detach().lerp(ema_param, ema_beta))
+
+
+def copy_dict(src_dict):
+    # Return a shallow copy so the original mapping is left untouched
+    # (e.g. `set_attn_processor` pops entries from the dict it receives).
+    return {key: value for key, value in src_dict.items()}
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `huggingface-cli login` to authenticate with the Hub."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ )
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Handle the repository creation.
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Load scheduler and models
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
+
+ # Importance sampling.
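+    # `importance_sampling_fn` (defined/imported earlier in this script) is assumed to return an unnormalized
+    # density over timesteps; normalizing it so the per-timestep weights average to 1 lets the training loop
+    # reweight the diffusion loss via `extract_into_tensor` without changing its overall scale.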
+ list_of_candidates = np.arange(noise_scheduler.config.num_train_timesteps, dtype='float64')
+ prob_dist = importance_sampling_fn(list_of_candidates, noise_scheduler.config.num_train_timesteps, 0.5)
+ importance_ratio = prob_dist / prob_dist.sum() * noise_scheduler.config.num_train_timesteps
+ importance_ratio = torch.from_numpy(importance_ratio.copy()).float()
+
+ # Load the tokenizers
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ use_fast=False,
+ )
+ tokenizer_2 = AutoTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ use_fast=False,
+ )
+
+ # Text encoder and image encoder.
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+ text_encoder = text_encoder_cls_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_2 = text_encoder_cls_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+
+ # Image processor and image encoder.
+ if args.use_clip_encoder:
+ image_processor = CLIPImageProcessor()
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
+ else:
+ image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
+ image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
+
+ # VAE.
+ vae_path = (
+ args.pretrained_model_name_or_path
+ if args.pretrained_vae_model_name_or_path is None
+ else args.pretrained_vae_model_name_or_path
+ )
+ vae = AutoencoderKL.from_pretrained(
+ vae_path,
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
+ revision=args.revision,
+ variant=args.variant,
+ )
+
+ # UNet.
+ unet = UNet2DConditionModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="unet",
+ revision=args.revision,
+ variant=args.variant
+ )
+
+ # Aggregator.
+ aggregator = Aggregator.from_unet(unet)
+ remove_attn2(aggregator)
+ if args.controlnet_model_name_or_path:
+ logger.info("Loading existing controlnet weights")
+ if args.controlnet_model_name_or_path.endswith(".safetensors"):
+ pretrained_cn_state_dict = {}
+ with safe_open(args.controlnet_model_name_or_path, framework="pt", device='cpu') as f:
+ for key in f.keys():
+ pretrained_cn_state_dict[key] = f.get_tensor(key)
+ else:
+ pretrained_cn_state_dict = torch.load(os.path.join(args.controlnet_model_name_or_path, "aggregator_ckpt.pt"), map_location="cpu")
+ aggregator.load_state_dict(pretrained_cn_state_dict, strict=True)
+ else:
+ logger.info("Initializing aggregator weights from unet.")
+
+ # Create image embedding projector for IP-Adapters.
+ if args.pretrained_adapter_model_path is not None:
+ if args.pretrained_adapter_model_path.endswith(".safetensors"):
+ pretrained_adapter_state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(args.pretrained_adapter_model_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ pretrained_adapter_state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ pretrained_adapter_state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ pretrained_adapter_state_dict = torch.load(args.pretrained_adapter_model_path, map_location="cpu")
+
+ # Image embedding Projector.
+ image_proj_model = Resampler(
+ dim=1280,
+ depth=4,
+ dim_head=64,
+ heads=20,
+ num_queries=args.adapter_tokens,
+ embedding_dim=image_encoder.config.hidden_size,
+ output_dim=unet.config.cross_attention_dim,
+ ff_mult=4
+ )
+
+ init_adapter_in_unet(
+ unet,
+ image_proj_model,
+ pretrained_adapter_state_dict,
+ adapter_tokens=args.adapter_tokens,
+ )
+
+ # EMA adapter for LCM preview.
+ if args.use_ema_adapter:
+ assert args.optimize_adapter, "No need for EMA with frozen adapter."
+ ema_image_proj_model = Resampler(
+ dim=1280,
+ depth=4,
+ dim_head=64,
+ heads=20,
+ num_queries=args.adapter_tokens,
+ embedding_dim=image_encoder.config.hidden_size,
+ output_dim=unet.config.cross_attention_dim,
+ ff_mult=4
+ )
+ orig_encoder_hid_proj = unet.encoder_hid_proj
+ ema_encoder_hid_proj = MultiIPAdapterImageProjection([ema_image_proj_model])
+ orig_attn_procs = unet.attn_processors
+ orig_attn_procs_list = torch.nn.ModuleList(orig_attn_procs.values())
+ ema_attn_procs = init_attn_proc(unet, args.adapter_tokens, True, True, False)
+ ema_attn_procs_list = torch.nn.ModuleList(ema_attn_procs.values())
+ ema_attn_procs_list.requires_grad_(False)
+ ema_encoder_hid_proj.requires_grad_(False)
+
+ # Initialize EMA state.
+ ema_beta = 0.5 ** (args.ema_update_steps / max(args.ema_halflife_steps, 1e-8))
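+        # Each EMA update scales the previous average by ema_beta, so after `ema_halflife_steps` optimizer
+        # steps (one update every `ema_update_steps` steps) the old state retains a weight of 0.5.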
+ logger.info(f"Using EMA with beta: {ema_beta}")
+ ema_encoder_hid_proj.load_state_dict(orig_encoder_hid_proj.state_dict())
+ ema_attn_procs_list.load_state_dict(orig_attn_procs_list.state_dict())
+
+ # Projector for aggregator.
+ if args.aggregator_adapter:
+ image_proj_model = Resampler(
+ dim=1280,
+ depth=4,
+ dim_head=64,
+ heads=20,
+ num_queries=args.adapter_tokens,
+ embedding_dim=image_encoder.config.hidden_size,
+ output_dim=unet.config.cross_attention_dim,
+ ff_mult=4
+ )
+
+ init_adapter_in_unet(
+ aggregator,
+ image_proj_model,
+ pretrained_adapter_state_dict,
+ adapter_tokens=args.adapter_tokens,
+ )
+ del pretrained_adapter_state_dict
+
+ # Load LCM LoRA into unet.
+ if args.pretrained_lcm_lora_path is not None:
+ lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(args.pretrained_lcm_lora_path)
+ unet_state_dict = {
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
+ }
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
+        lora_state_dict = dict()
+        for k, v in unet_state_dict.items():
+            # IP-Adapter LoRA weights target the custom attention processors, so remap their keys accordingly.
+            if "ip" in k:
+                k = k.replace("attn2", "attn2.processor")
+            lora_state_dict[k] = v
+ if alpha_dict:
+ args.lora_alpha = next(iter(alpha_dict.values()))
+ else:
+ args.lora_alpha = 1
+ logger.info(f"Loaded LCM LoRA with alpha: {args.lora_alpha}")
+ # Create LoRA config, FIXME: now hard-coded.
+ lora_target_modules = [
+ "to_q",
+ "to_kv",
+ "0.to_out",
+ "attn1.to_k",
+ "attn1.to_v",
+ "to_k_ip",
+ "to_v_ip",
+ "ln_k_ip.linear",
+ "ln_v_ip.linear",
+ "to_out.0",
+ "proj_in",
+ "proj_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "conv1",
+ "conv2",
+ "conv_shortcut",
+ "downsamplers.0.conv",
+ "upsamplers.0.conv",
+ "time_emb_proj",
+ ]
+ lora_config = LoraConfig(
+ r=args.lora_rank,
+ target_modules=lora_target_modules,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ )
+
+ unet.add_adapter(lora_config)
+ if args.pretrained_lcm_lora_path is not None:
+ incompatible_keys = set_peft_model_state_dict(unet, lora_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
+ if unexpected_keys:
+ raise ValueError(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+ for k in missing_keys:
+ if "lora" in k:
+ raise ValueError(
+ f"Loading adapter weights from state_dict led to missing keys: {missing_keys}. "
+ )
+ lcm_scheduler = LCMSingleStepScheduler.from_config(noise_scheduler.config)
+
+ # Initialize training state.
+ vae.requires_grad_(False)
+ image_encoder.requires_grad_(False)
+ text_encoder.requires_grad_(False)
+ text_encoder_2.requires_grad_(False)
+ unet.requires_grad_(False)
+ aggregator.requires_grad_(False)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # `accelerate` 0.16.0 will have better support for customized saving
+ if args.save_only_adapter:
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ for model in models:
+ if isinstance(model, Aggregator):
+ torch.save(model.state_dict(), os.path.join(output_dir, "aggregator_ckpt.pt"))
+ weights.pop()
+
+ def load_model_hook(models, input_dir):
+
+ while len(models) > 0:
+ # pop models so that they are not loaded again
+ model = models.pop()
+
+ if isinstance(model, Aggregator):
+ aggregator_state_dict = torch.load(os.path.join(input_dir, "aggregator_ckpt.pt"), map_location="cpu")
+ model.load_state_dict(aggregator_state_dict)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
+ # as these models are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ import xformers
+
+ xformers_version = version.parse(xformers.__version__)
+ if xformers_version == version.parse("0.0.16"):
+ logger.warning(
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co./docs/diffusers/main/en/optimization/xformers for more details."
+ )
+ unet.enable_xformers_memory_efficient_attention()
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ if args.gradient_checkpointing:
+ aggregator.enable_gradient_checkpointing()
+ unet.enable_gradient_checkpointing()
+
+ # Check that all trainable models are in full precision
+ low_precision_error_string = (
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
+ " doing mixed precision training, copy of the weights should still be float32."
+ )
+
+ if unwrap_model(aggregator).dtype != torch.float32:
+ raise ValueError(
+ f"aggregator loaded as datatype {unwrap_model(aggregator).dtype}. {low_precision_error_string}"
+ )
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32:
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # Optimizer creation
+ ip_params, non_ip_params = seperate_ip_params_from_unet(unet)
+ if args.optimize_adapter:
+ unet_params = ip_params
+ unet_frozen_params = non_ip_params
+ else: # freeze all unet params
+ unet_params = []
+ unet_frozen_params = ip_params + non_ip_params
+ assert len(unet_frozen_params) == len(list(unet.parameters()))
+ params_to_optimize = [p for p in aggregator.parameters()]
+ params_to_optimize += unet_params
+ optimizer = optimizer_class(
+ params_to_optimize,
+ lr=args.learning_rate,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ # Instantiate Loss.
+ losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
+ diffusion_losses = list()
+ lcm_losses = list()
+ for loss_config in losses_configs.diffusion_losses:
+ logger.info(f"Using diffusion loss: {loss_config.name}")
+ loss = namedtuple("loss", ["loss", "weight"])
+ diffusion_losses.append(
+ loss(loss=eval(loss_config.name)(
+ visualize_every_k=loss_config.visualize_every_k,
+ dtype=weight_dtype,
+ accelerator=accelerator,
+ **loss_config.init_params), weight=loss_config.weight)
+ )
+ for loss_config in losses_configs.lcm_losses:
+ logger.info(f"Using lcm loss: {loss_config.name}")
+ loss = namedtuple("loss", ["loss", "weight"])
+ loss_class = eval(loss_config.name)
+ lcm_losses.append(loss(loss=loss_class(visualize_every_k=loss_config.visualize_every_k,
+ dtype=weight_dtype,
+ accelerator=accelerator,
+ dino_model=image_encoder,
+ dino_preprocess=image_processor,
+ **loss_config.init_params), weight=loss_config.weight))
+
+ # SDXL additional condition that will be added to time embedding.
+ def compute_time_ids(original_size, crops_coords_top_left):
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
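+        # SDXL micro-conditioning: the UNet receives (original_size, crop_top_left, target_size) as extra
+        # time-embedding inputs, so generation can account for the source resolution and cropping.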
+ target_size = (args.resolution, args.resolution)
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_time_ids = torch.tensor([add_time_ids])
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
+ return add_time_ids
+
+ # Text prompt embeddings.
+ @torch.no_grad()
+ def compute_embeddings(batch, text_encoders, tokenizers, proportion_empty_prompts=0.0, drop_idx=None, is_train=True):
+ prompt_batch = batch[args.caption_column]
+ if drop_idx is not None:
+ for i in range(len(prompt_batch)):
+ prompt_batch[i] = "" if drop_idx[i] else prompt_batch[i]
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
+ prompt_batch, text_encoders, tokenizers, is_train
+ )
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+ unet_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
+
+ return prompt_embeds, unet_added_cond_kwargs
+
+ @torch.no_grad()
+ def get_added_cond(batch, prompt_embeds, pooled_prompt_embeds, proportion_empty_prompts=0.0, drop_idx=None, is_train=True):
+
+ add_time_ids = torch.cat(
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
+ )
+
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
+ unet_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
+
+ return prompt_embeds, unet_added_cond_kwargs
+
+ # Move pixels into latents.
+ @torch.no_grad()
+ def convert_to_latent(pixels):
+ model_input = vae.encode(pixels).latent_dist.sample()
+ model_input = model_input * vae.config.scaling_factor
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+ return model_input
+
+ # Helper functions for training loop.
+ # if args.degradation_config_path is not None:
+ # with open(args.degradation_config_path) as stream:
+ # degradation_configs = yaml.safe_load(stream)
+ # deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
+ # else:
+ deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
+ compute_embeddings_fn = functools.partial(
+ compute_embeddings,
+ text_encoders=[text_encoder, text_encoder_2],
+ tokenizers=[tokenizer, tokenizer_2],
+ is_train=True,
+ )
+
+ datasets = []
+ datasets_name = []
+ datasets_weights = []
+ if args.data_config_path is not None:
+ data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
+ for single_dataset in data_config.datasets:
+ datasets_weights.append(single_dataset.dataset_weight)
+ datasets_name.append(single_dataset.dataset_folder)
+ dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
+ image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
+ image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
+ datasets.append(image_dataset)
+        # TODO: validation dataset is not used yet; mirror the training call above to avoid
+        # referencing an undefined `dataset_name`.
+        if data_config.val_dataset is not None:
+            val_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
+ logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
+
+ # Mix training datasets.
+ sampler_train = None
+ if len(datasets) == 1:
+ train_dataset = datasets[0]
+ else:
+        # Weight every sample so each dataset is drawn in proportion to its configured weight, regardless of its size.
+ train_dataset = torch.utils.data.ConcatDataset(datasets)
+ dataset_weights = []
+ for single_dataset, single_weight in zip(datasets, datasets_weights):
+ dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
+ sampler_train = torch.utils.data.WeightedRandomSampler(
+ weights=dataset_weights,
+ num_samples=len(dataset_weights)
+ )
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_size=args.train_batch_size,
+ sampler=sampler_train,
+ shuffle=True if sampler_train is None else False,
+ collate_fn=collate_fn,
+ num_workers=args.dataloader_num_workers
+ )
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_config = dict(vars(args))
+
+ # tensorboard cannot handle list types for config
+ tracker_config.pop("validation_prompt")
+ tracker_config.pop("validation_image")
+
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
+
+ # Scheduler and math around the number of training steps.
+ overrode_max_train_steps = False
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ overrode_max_train_steps = True
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
+ num_training_steps=args.max_train_steps,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ aggregator, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ aggregator, unet, optimizer, train_dataloader, lr_scheduler
+ )
+
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
+
+ # # cache empty prompts and move text encoders to cpu
+ # empty_prompt_embeds, empty_pooled_prompt_embeds = encode_prompt(
+ # [""]*args.train_batch_size, [text_encoder, text_encoder_2], [tokenizer, tokenizer_2], True
+ # )
+ # compute_embeddings_fn = functools.partial(
+ # get_added_cond,
+ # prompt_embeds=empty_prompt_embeds,
+ # pooled_prompt_embeds=empty_pooled_prompt_embeds,
+ # is_train=True,
+ # )
+ # text_encoder.to("cpu")
+ # text_encoder_2.to("cpu")
+
+ # Move vae, unet and text_encoder to device and cast to `weight_dtype`.
+ if args.pretrained_vae_model_name_or_path is None:
+ # The VAE is fp32 to avoid NaN losses.
+ vae.to(accelerator.device, dtype=torch.float32)
+ else:
+ vae.to(accelerator.device, dtype=weight_dtype)
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
+ if args.use_ema_adapter:
+ # FIXME: prepare ema model
+ # ema_encoder_hid_proj, ema_attn_procs_list = accelerator.prepare(ema_encoder_hid_proj, ema_attn_procs_list)
+ ema_encoder_hid_proj.to(accelerator.device)
+ ema_attn_procs_list.to(accelerator.device)
+ for param in unet_frozen_params:
+ param.data = param.data.to(dtype=weight_dtype)
+ for param in unet_params:
+ param.requires_grad_(True)
+ unet.to(accelerator.device)
+ aggregator.requires_grad_(True)
+ aggregator.to(accelerator.device)
+ importance_ratio = importance_ratio.to(accelerator.device)
+
+ # Final sanity check.
+    for n, p in unet.named_parameters():
+        if p.requires_grad:
+            assert p.dtype == torch.float32, n
+        else:
+            assert p.dtype == weight_dtype, n
+    for n, p in aggregator.named_parameters():
+        if p.requires_grad:
+            assert p.dtype == torch.float32, n
+        else:
+            raise RuntimeError(f"All parameters in aggregator should require grad. {n}")
+ unwrap_model(unet).disable_adapters()
+
+ if args.sanity_check:
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+
+ if args.use_ema_adapter:
+ unwrap_model(unet).set_attn_processor(ema_attn_procs)
+ unwrap_model(unet).encoder_hid_proj = ema_encoder_hid_proj
+ batch = next(iter(train_dataloader))
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
+ log_validation(
+ unwrap_model(unet), unwrap_model(aggregator), vae,
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
+ args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, log_local=True
+ )
+ if args.use_ema_adapter:
+ unwrap_model(unet).set_attn_processor(orig_attn_procs)
+ unwrap_model(unet).encoder_hid_proj = orig_encoder_hid_proj
+ for n, p in unet.named_parameters():
+ if p.requires_grad: assert p.dtype == torch.float32, n
+ else: assert p.dtype == weight_dtype, n
+ for n, p in aggregator.named_parameters():
+ if p.requires_grad: assert p.dtype == torch.float32, n
+ else: assert p.dtype == weight_dtype, n
+ print("PASS")
+ exit()
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if overrode_max_train_steps:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Optimization steps per epoch = {num_update_steps_per_epoch}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the most recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ trainable_models = [aggregator, unet]
+
+    # TODO: add vae to `checkpoint_models` when gradient checkpointing is enabled.
+    checkpoint_models = []
+
+ image_logs = None
+ tic = time.time()
+ for epoch in range(first_epoch, args.num_train_epochs):
+ for step, batch in enumerate(train_dataloader):
+ toc = time.time()
+ io_time = toc - tic
+ tic = time.time()
+ for model in trainable_models + checkpoint_models:
+ model.train()
+ with accelerator.accumulate(*trainable_models):
+ loss = torch.tensor(0.0)
+
+ # Drop conditions.
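+                # A single uniform draw per sample partitions [0, 1) into disjoint buckets: the first
+                # `image_drop_rate`-wide interval drops only the image condition, the next `text_drop_rate`-wide
+                # interval drops only the text prompt, and the next `cond_drop_rate`-wide interval drops both.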
+ rand_tensor = torch.rand(batch["images"].shape[0])
+ drop_image_idx = rand_tensor < args.image_drop_rate
+ drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
+ drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
+ drop_image_idx = drop_image_idx | drop_both_idx
+ drop_text_idx = drop_text_idx | drop_both_idx
+
+ # Get LQ embeddings
+ with torch.no_grad():
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
+ lq_pt = image_processor(
+ images=lq_img*0.5+0.5,
+ do_rescale=False, return_tensors="pt"
+ ).pixel_values
+
+ # Move inputs to latent space.
+ gt_img = gt_img.to(dtype=vae.dtype)
+ lq_img = lq_img.to(dtype=vae.dtype)
+ model_input = convert_to_latent(gt_img)
+ lq_latent = convert_to_latent(lq_img)
+ if args.pretrained_vae_model_name_or_path is None:
+ model_input = model_input.to(weight_dtype)
+ lq_latent = lq_latent.to(weight_dtype)
+
+ # Process conditions.
+ image_embeds = prepare_training_image_embeds(
+ image_encoder, image_processor,
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
+ device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
+ idx_to_replace=drop_image_idx
+ )
+ prompt_embeds_input, added_conditions = compute_embeddings_fn(batch, drop_idx=drop_text_idx)
+
+ # Sample noise that we'll add to the latents.
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)
+
+ # Add noise to the model input according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
+ loss_weights = extract_into_tensor(importance_ratio, timesteps, noise.shape) if args.importance_sampling else None
+
+ if args.CFG_scale > 1.0:
+ # Process negative conditions.
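+                    # Build a fully dropped (unconditional) set of embeddings and batch it with the conditional
+                    # set, so a single UNet forward pass below yields both predictions needed for CFG.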
+ drop_idx = torch.ones_like(drop_image_idx)
+ neg_image_embeds = prepare_training_image_embeds(
+ image_encoder, image_processor,
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
+ device=accelerator.device, drop_rate=1.0, output_hidden_state=args.image_encoder_hidden_feature,
+ idx_to_replace=drop_idx
+ )
+ neg_prompt_embeds_input, neg_added_conditions = compute_embeddings_fn(batch, drop_idx=drop_idx)
+ previewer_model_input = torch.cat([noisy_model_input] * 2)
+ previewer_timesteps = torch.cat([timesteps] * 2)
+ previewer_prompt_embeds = torch.cat([neg_prompt_embeds_input, prompt_embeds_input], dim=0)
+ previewer_added_conditions = {}
+ for k in added_conditions.keys():
+ previewer_added_conditions[k] = torch.cat([neg_added_conditions[k], added_conditions[k]], dim=0)
+ previewer_image_embeds = []
+ for neg_image_embed, image_embed in zip(neg_image_embeds, image_embeds):
+ previewer_image_embeds.append(torch.cat([neg_image_embed, image_embed], dim=0))
+ previewer_added_conditions["image_embeds"] = previewer_image_embeds
+ else:
+ previewer_model_input = noisy_model_input
+ previewer_timesteps = timesteps
+ previewer_prompt_embeds = prompt_embeds_input
+ previewer_added_conditions = {}
+ for k in added_conditions.keys():
+ previewer_added_conditions[k] = added_conditions[k]
+ previewer_added_conditions["image_embeds"] = image_embeds
+
+ # Get LCM preview latent
+ if args.use_ema_adapter:
+ orig_encoder_hid_proj = unet.encoder_hid_proj
+ orig_attn_procs = unet.attn_processors
+ _ema_attn_procs = copy_dict(ema_attn_procs)
+ unwrap_model(unet).set_attn_processor(_ema_attn_procs)
+ unwrap_model(unet).encoder_hid_proj = ema_encoder_hid_proj
+ unwrap_model(unet).enable_adapters()
+ preview_noise = unet(
+ previewer_model_input, previewer_timesteps,
+ encoder_hidden_states=previewer_prompt_embeds,
+ added_cond_kwargs=previewer_added_conditions,
+ return_dict=False
+ )[0]
+ if args.CFG_scale > 1.0:
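+                        # Classifier-free guidance for the preview: a per-sample guidance scale is drawn
+                        # uniformly from [1, CFG_scale] and applied as uncond + scale * (cond - uncond).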
+ preview_noise_uncond, preview_noise_cond = preview_noise.chunk(2)
+ cfg_scale = 1.0 + torch.rand_like(timesteps, dtype=weight_dtype) * (args.CFG_scale-1.0)
+ cfg_scale = cfg_scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ preview_noise = preview_noise_uncond + cfg_scale * (preview_noise_cond - preview_noise_uncond)
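+                    # Single-step LCM denoising turns the noise prediction into a clean latent "preview",
+                    # which is used below as the conditioning input to the aggregator.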
+ preview_latents = lcm_scheduler.step(
+ preview_noise,
+ timesteps,
+ noisy_model_input,
+ return_dict=False
+ )[0]
+ unwrap_model(unet).disable_adapters()
+ if args.use_ema_adapter:
+ unwrap_model(unet).set_attn_processor(orig_attn_procs)
+ unwrap_model(unet).encoder_hid_proj = orig_encoder_hid_proj
+ preview_error_latent = F.mse_loss(preview_latents, model_input).cpu().item()
+ preview_error_noise = F.mse_loss(preview_noise, noise).cpu().item()
+
+ # # Add fresh noise
+ # if args.noisy_encoder_input:
+ # aggregator_noise = torch.randn_like(preview_latents)
+ # aggregator_input = noise_scheduler.add_noise(preview_latents, aggregator_noise, timesteps)
+
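+                # The aggregator takes the LQ latent together with the LCM preview and produces
+                # ControlNet-style residuals that are injected into the UNet's down/mid blocks below.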
+ down_block_res_samples, mid_block_res_sample = aggregator(
+ lq_latent,
+ timesteps,
+ encoder_hidden_states=prompt_embeds_input,
+ added_cond_kwargs=added_conditions,
+ controlnet_cond=preview_latents,
+ conditioning_scale=1.0,
+ return_dict=False,
+ )
+
+ # UNet denoise.
+ added_conditions["image_embeds"] = image_embeds
+ model_pred = unet(
+ noisy_model_input,
+ timesteps,
+ encoder_hidden_states=prompt_embeds_input,
+ added_cond_kwargs=added_conditions,
+ down_block_additional_residuals=[
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
+ ],
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
+ return_dict=False
+ )[0]
+
+ diffusion_loss_arguments = {
+ "target": noise,
+ "predict": model_pred,
+ "prompt_embeddings_input": prompt_embeds_input,
+ "timesteps": timesteps,
+ "weights": loss_weights,
+ }
+
+ loss_dict = dict()
+ for loss_config in diffusion_losses:
+ non_weighted_loss = loss_config.loss(**diffusion_loss_arguments, accelerator=accelerator)
+ loss = loss + non_weighted_loss * loss_config.weight
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
+
+ toc = time.time()
+ forward_time = toc - tic
+ tic = toc
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if global_step % args.ema_update_steps == 0:
+ if args.use_ema_adapter:
+ update_ema_model(ema_encoder_hid_proj, orig_encoder_hid_proj, ema_beta)
+ update_ema_model(ema_attn_procs_list, orig_attn_procs_list, ema_beta)
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ if global_step % args.validation_steps == 0:
+ image_logs = log_validation(
+ unwrap_model(unet), unwrap_model(aggregator), vae,
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
+ args, accelerator, weight_dtype, global_step, lq_img.detach().clone(), gt_img.detach().clone()
+ )
+
+ logs = {}
+ logs.update(loss_dict)
+ logs.update(
+ {"preview_error_latent": preview_error_latent, "preview_error_noise": preview_error_noise,
+ "lr": lr_scheduler.get_last_lr()[0],
+ "forward_time": forward_time, "io_time": io_time}
+ )
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+ tic = time.time()
+
+ if global_step >= args.max_train_steps:
+ break
+
+    # Save the final training state and run a last round of validation with the trained modules.
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+        # `save_path` is undefined if no intermediate checkpoint was written, so build it here.
+        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+        accelerator.save_state(save_path, safe_serialization=False)
+        # Run a final round of validation.
+ image_logs = None
+ if args.validation_image is not None:
+ image_logs = log_validation(
+ unwrap_model(unet), unwrap_model(aggregator), vae,
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
+ noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
+ args, accelerator, weight_dtype, global_step,
+ )
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
\ No newline at end of file
diff --git a/train_stage2_aggregator.sh b/train_stage2_aggregator.sh
new file mode 100644
index 0000000000000000000000000000000000000000..768821eec2afb564be1094c10eb63632e5d741d1
--- /dev/null
+++ b/train_stage2_aggregator.sh
@@ -0,0 +1,24 @@
+# Stage 2: train aggregator
+accelerate launch --num_processes train_stage2_aggregator.py \
+ --output_dir \
+ --train_data_dir \
+ --logging_dir \
+ --pretrained_model_name_or_path \
+ --feature_extractor_path \
+ --pretrained_adapter_model_path \
+ --pretrained_lcm_lora_path \
+ --losses_config_path config_files/losses.yaml \
+ --data_config_path config_files/IR_dataset.yaml \
+ --image_drop_rate 0.0 \
+ --text_drop_rate 0.85 \
+ --cond_drop_rate 0.15 \
+ --save_only_adapter \
+ --gradient_checkpointing \
+ --mixed_precision fp16 \
+ --train_batch_size 6 \
+ --gradient_accumulation_steps 2 \
+ --learning_rate 1e-4 \
+ --lr_warmup_steps 1000 \
+ --lr_scheduler cosine \
+ --lr_num_cycles 1 \
+ --resume_from_checkpoint latest
\ No newline at end of file
diff --git a/utils/degradation_pipeline.py b/utils/degradation_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cd1ca9b6b8df12003d7a14862f8f795bd5c83c0
--- /dev/null
+++ b/utils/degradation_pipeline.py
@@ -0,0 +1,353 @@
+import cv2
+import math
+import numpy as np
+import random
+import torch
+from torch.utils import data as data
+
+from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
+from basicsr.data.transforms import augment
+from basicsr.utils import img2tensor, DiffJPEG, USMSharp
+from basicsr.utils.img_process_util import filter2D
+from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
+from basicsr.data.transforms import paired_random_crop
+
+AUGMENT_OPT = {
+ 'use_hflip': False,
+ 'use_rot': False
+}
+
+KERNEL_OPT = {
+ 'blur_kernel_size': 21,
+ 'kernel_list': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
+ 'kernel_prob': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
+ 'sinc_prob': 0.1,
+ 'blur_sigma': [0.2, 3],
+ 'betag_range': [0.5, 4],
+ 'betap_range': [1, 2],
+
+ 'blur_kernel_size2': 21,
+ 'kernel_list2': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
+ 'kernel_prob2': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
+ 'sinc_prob2': 0.1,
+ 'blur_sigma2': [0.2, 1.5],
+ 'betag_range2': [0.5, 4],
+ 'betap_range2': [1, 2],
+ 'final_sinc_prob': 0.8,
+}
+
+DEGRADE_OPT = {
+ 'resize_prob': [0.2, 0.7, 0.1], # up, down, keep
+ 'resize_range': [0.15, 1.5],
+ 'gaussian_noise_prob': 0.5,
+ 'noise_range': [1, 30],
+ 'poisson_scale_range': [0.05, 3],
+ 'gray_noise_prob': 0.4,
+ 'jpeg_range': [30, 95],
+
+ # the second degradation process
+ 'second_blur_prob': 0.8,
+ 'resize_prob2': [0.3, 0.4, 0.3], # up, down, keep
+ 'resize_range2': [0.3, 1.2],
+ 'gaussian_noise_prob2': 0.5,
+ 'noise_range2': [1, 25],
+ 'poisson_scale_range2': [0.05, 2.5],
+ 'gray_noise_prob2': 0.4,
+ 'jpeg_range2': [30, 95],
+
+ 'gt_size': 512,
+ 'no_degradation_prob': 0.01,
+ 'use_usm': True,
+ 'sf': 4,
+ 'random_size': False,
+ 'resize_lq': True
+}
+
+class RealESRGANDegradation:
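+    """Synthesize low-quality (LQ) inputs from ground-truth (GT) batches with a Real-ESRGAN-style two-stage
+    degradation (blur -> random resize -> noise -> JPEG, applied twice), followed by a final sinc filter,
+    a resize back to 1/sf resolution, and a paired random crop. Returns (LQ, GT) tensors scaled to [-1, 1]."""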
+
+ def __init__(self, augment_opt=None, kernel_opt=None, degrade_opt=None, device='cuda', resolution=None):
+ if augment_opt is None:
+ augment_opt = AUGMENT_OPT
+ self.augment_opt = augment_opt
+ if kernel_opt is None:
+ kernel_opt = KERNEL_OPT
+ self.kernel_opt = kernel_opt
+ if degrade_opt is None:
+ degrade_opt = DEGRADE_OPT
+ self.degrade_opt = degrade_opt
+ if resolution is not None:
+ self.degrade_opt['gt_size'] = resolution
+ self.device = device
+
+ self.jpeger = DiffJPEG(differentiable=False).to(self.device)
+ self.usm_sharpener = USMSharp().to(self.device)
+
+ # blur settings for the first degradation
+ self.blur_kernel_size = kernel_opt['blur_kernel_size']
+ self.kernel_list = kernel_opt['kernel_list']
+ self.kernel_prob = kernel_opt['kernel_prob'] # a list for each kernel probability
+ self.blur_sigma = kernel_opt['blur_sigma']
+ self.betag_range = kernel_opt['betag_range'] # betag used in generalized Gaussian blur kernels
+ self.betap_range = kernel_opt['betap_range'] # betap used in plateau blur kernels
+ self.sinc_prob = kernel_opt['sinc_prob'] # the probability for sinc filters
+
+ # blur settings for the second degradation
+ self.blur_kernel_size2 = kernel_opt['blur_kernel_size2']
+ self.kernel_list2 = kernel_opt['kernel_list2']
+ self.kernel_prob2 = kernel_opt['kernel_prob2']
+ self.blur_sigma2 = kernel_opt['blur_sigma2']
+ self.betag_range2 = kernel_opt['betag_range2']
+ self.betap_range2 = kernel_opt['betap_range2']
+ self.sinc_prob2 = kernel_opt['sinc_prob2']
+
+ # a final sinc filter
+ self.final_sinc_prob = kernel_opt['final_sinc_prob']
+
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
+ # TODO: kernel range is now hard-coded, should be in the configure file
+ self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
+ self.pulse_tensor[10, 10] = 1
+
+ def get_kernel(self):
+
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
+ kernel_size = random.choice(self.kernel_range)
+ if np.random.uniform() < self.kernel_opt['sinc_prob']:
+ # this sinc filter setting is for kernels ranging from [7, 21]
+ if kernel_size < 13:
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ else:
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+ else:
+ kernel = random_mixed_kernels(
+ self.kernel_list,
+ self.kernel_prob,
+ kernel_size,
+ self.blur_sigma,
+ self.blur_sigma, [-math.pi, math.pi],
+ self.betag_range,
+ self.betap_range,
+ noise_range=None)
+ # pad kernel
+ pad_size = (21 - kernel_size) // 2
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
+
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
+ kernel_size = random.choice(self.kernel_range)
+ if np.random.uniform() < self.kernel_opt['sinc_prob2']:
+ if kernel_size < 13:
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ else:
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
+ else:
+ kernel2 = random_mixed_kernels(
+ self.kernel_list2,
+ self.kernel_prob2,
+ kernel_size,
+ self.blur_sigma2,
+ self.blur_sigma2, [-math.pi, math.pi],
+ self.betag_range2,
+ self.betap_range2,
+ noise_range=None)
+
+ # pad kernel
+ pad_size = (21 - kernel_size) // 2
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
+
+ # ------------------------------------- the final sinc kernel ------------------------------------- #
+ if np.random.uniform() < self.kernel_opt['final_sinc_prob']:
+ kernel_size = random.choice(self.kernel_range)
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
+ else:
+ sinc_kernel = self.pulse_tensor
+
+        # numpy to tensor
+ kernel = torch.FloatTensor(kernel)
+ kernel2 = torch.FloatTensor(kernel2)
+
+ return (kernel, kernel2, sinc_kernel)
+
+ @torch.no_grad()
+ def __call__(self, img_gt, kernels=None):
+ '''
+ :param: img_gt: BCHW, RGB, [0, 1] float32 tensor
+ '''
+ if kernels is None:
+ kernel = []
+ kernel2 = []
+ sinc_kernel = []
+ for _ in range(img_gt.shape[0]):
+ k, k2, sk = self.get_kernel()
+ kernel.append(k)
+ kernel2.append(k2)
+ sinc_kernel.append(sk)
+ kernel = torch.stack(kernel)
+ kernel2 = torch.stack(kernel2)
+ sinc_kernel = torch.stack(sinc_kernel)
+ else:
+ # kernels created in dataset.
+ kernel, kernel2, sinc_kernel = kernels
+
+ # ----------------------- Pre-process ----------------------- #
+ im_gt = img_gt.to(self.device)
+ if self.degrade_opt['use_usm']:
+ im_gt = self.usm_sharpener(im_gt)
+ im_gt = im_gt.to(memory_format=torch.contiguous_format).float()
+ kernel = kernel.to(self.device)
+ kernel2 = kernel2.to(self.device)
+ sinc_kernel = sinc_kernel.to(self.device)
+ ori_h, ori_w = im_gt.size()[2:4]
+
+ # ----------------------- The first degradation process ----------------------- #
+ # blur
+ out = filter2D(im_gt, kernel)
+ # random resize
+ updown_type = random.choices(
+ ['up', 'down', 'keep'],
+ self.degrade_opt['resize_prob'],
+ )[0]
+ if updown_type == 'up':
+ scale = random.uniform(1, self.degrade_opt['resize_range'][1])
+ elif updown_type == 'down':
+ scale = random.uniform(self.degrade_opt['resize_range'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = torch.nn.functional.interpolate(out, scale_factor=scale, mode=mode)
+ # add noise
+ gray_noise_prob = self.degrade_opt['gray_noise_prob']
+ if random.random() < self.degrade_opt['gaussian_noise_prob']:
+ out = random_add_gaussian_noise_pt(
+ out,
+ sigma_range=self.degrade_opt['noise_range'],
+ clip=True,
+ rounds=False,
+ gray_prob=gray_noise_prob,
+ )
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.degrade_opt['poisson_scale_range'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range'])
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
+ out = self.jpeger(out, quality=jpeg_p)
+
+ # ----------------------- The second degradation process ----------------------- #
+ # blur
+ if random.random() < self.degrade_opt['second_blur_prob']:
+ out = out.contiguous()
+ out = filter2D(out, kernel2)
+ # random resize
+ updown_type = random.choices(
+ ['up', 'down', 'keep'],
+ self.degrade_opt['resize_prob2'],
+ )[0]
+ if updown_type == 'up':
+ scale = random.uniform(1, self.degrade_opt['resize_range2'][1])
+ elif updown_type == 'down':
+ scale = random.uniform(self.degrade_opt['resize_range2'][0], 1)
+ else:
+ scale = 1
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = torch.nn.functional.interpolate(
+ out,
+ size=(int(ori_h / self.degrade_opt['sf'] * scale),
+ int(ori_w / self.degrade_opt['sf'] * scale)),
+ mode=mode,
+ )
+ # add noise
+ gray_noise_prob = self.degrade_opt['gray_noise_prob2']
+ if random.random() < self.degrade_opt['gaussian_noise_prob2']:
+ out = random_add_gaussian_noise_pt(
+ out,
+ sigma_range=self.degrade_opt['noise_range2'],
+ clip=True,
+ rounds=False,
+ gray_prob=gray_noise_prob,
+ )
+ else:
+ out = random_add_poisson_noise_pt(
+ out,
+ scale_range=self.degrade_opt['poisson_scale_range2'],
+ gray_prob=gray_noise_prob,
+ clip=True,
+ rounds=False,
+ )
+
+ # JPEG compression + the final sinc filter
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
+ # as one operation.
+ # We consider two orders:
+ # 1. [resize back + sinc filter] + JPEG compression
+ # 2. JPEG compression + [resize back + sinc filter]
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
+ if random.random() < 0.5:
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = torch.nn.functional.interpolate(
+ out,
+ size=(ori_h // self.degrade_opt['sf'],
+ ori_w // self.degrade_opt['sf']),
+ mode=mode,
+ )
+ out = out.contiguous()
+ out = filter2D(out, sinc_kernel)
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ else:
+ # JPEG compression
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range2'])
+ out = torch.clamp(out, 0, 1)
+ out = self.jpeger(out, quality=jpeg_p)
+ # resize back + the final sinc filter
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
+ out = torch.nn.functional.interpolate(
+ out,
+ size=(ori_h // self.degrade_opt['sf'],
+ ori_w // self.degrade_opt['sf']),
+ mode=mode,
+ )
+ out = out.contiguous()
+ out = filter2D(out, sinc_kernel)
+
+ # clamp and round
+ im_lq = torch.clamp(out, 0, 1.0)
+
+ # random crop
+ gt_size = self.degrade_opt['gt_size']
+ im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, self.degrade_opt['sf'])
+
+ if self.degrade_opt['resize_lq']:
+ im_lq = torch.nn.functional.interpolate(
+ im_lq,
+ size=(im_gt.size(-2),
+ im_gt.size(-1)),
+ mode='bicubic',
+ )
+
+ if random.random() < self.degrade_opt['no_degradation_prob'] or torch.isnan(im_lq).any():
+ im_lq = im_gt
+
+        # rescale LQ/GT from [0, 1] to [-1, 1]
+ im_lq = im_lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
+ im_lq = im_lq*2 - 1.0
+ im_gt = im_gt*2 - 1.0
+
+        if self.degrade_opt['random_size']:
+            # TODO: random-size cropping (`randn_cropinput`) is not implemented.
+            raise NotImplementedError
+
+ im_lq = torch.clamp(im_lq, -1.0, 1.0)
+ im_gt = torch.clamp(im_gt, -1.0, 1.0)
+
+ return (im_lq, im_gt)
\ No newline at end of file
diff --git a/utils/matlab_cp2tform.py b/utils/matlab_cp2tform.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdcdf96ab45577bcea54c809b4d241c9e8a6a74e
--- /dev/null
+++ b/utils/matlab_cp2tform.py
@@ -0,0 +1,350 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Jul 11 06:54:28 2017
+
+@author: zhaoyafei
+"""
+
+import numpy as np
+from numpy.linalg import inv, norm, lstsq
+from numpy.linalg import matrix_rank as rank
+
+class MatlabCp2tormException(Exception):
+    def __str__(self):
+        return 'In File {}:{}'.format(
+            __file__, super().__str__())
+
+def tformfwd(trans, uv):
+ """
+ Function:
+ ----------
+ apply affine transform 'trans' to uv
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix
+ @uv: Kx2 np.array
+ each row is a pair of coordinates (x, y)
+
+ Returns:
+ ----------
+ @xy: Kx2 np.array
+ each row is a pair of transformed coordinates (x, y)
+ """
+ uv = np.hstack((
+ uv, np.ones((uv.shape[0], 1))
+ ))
+ xy = np.dot(uv, trans)
+ xy = xy[:, 0:-1]
+ return xy
+
+
+def tforminv(trans, uv):
+ """
+ Function:
+ ----------
+ apply the inverse of affine transform 'trans' to uv
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix
+ @uv: Kx2 np.array
+ each row is a pair of coordinates (x, y)
+
+ Returns:
+ ----------
+ @xy: Kx2 np.array
+ each row is a pair of inverse-transformed coordinates (x, y)
+ """
+ Tinv = inv(trans)
+ xy = tformfwd(Tinv, uv)
+ return xy
+
+
+def findNonreflectiveSimilarity(uv, xy, options=None):
+
+    if options is None:
+        options = {'K': 2}
+
+    K = options['K']
+ M = xy.shape[0]
+ x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
+ y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
+ # print('--->x, y:\n', x, y
+
+ tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
+ tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
+ X = np.vstack((tmp1, tmp2))
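+    # Stack the K correspondences into a linear least-squares system X * [sc, ss, tx, ty]^T = U,
+    # where the four unknowns parameterize a non-reflective similarity (uniform scale, rotation, translation).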
+ # print('--->X.shape: ', X.shape
+ # print('X:\n', X
+
+ u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
+ v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
+ U = np.vstack((u, v))
+ # print('--->U.shape: ', U.shape
+ # print('U:\n', U
+
+ # We know that X * r = U
+ if rank(X) >= 2 * K:
+ r, _, _, _ = lstsq(X, U)
+ r = np.squeeze(r)
+ else:
+ raise Exception('cp2tform:twoUniquePointsReq')
+
+ # print('--->r:\n', r
+
+ sc = r[0]
+ ss = r[1]
+ tx = r[2]
+ ty = r[3]
+
+ Tinv = np.array([
+ [sc, -ss, 0],
+ [ss, sc, 0],
+ [tx, ty, 1]
+ ])
+
+ # print('--->Tinv:\n', Tinv
+
+ T = inv(Tinv)
+ # print('--->T:\n', T
+
+ T[:, 2] = np.array([0, 0, 1])
+
+ return T, Tinv
+
+
+def findSimilarity(uv, xy, options=None):
+
+ if options is None:
+ options = {'K': 2}
+
+# uv = np.array(uv)
+# xy = np.array(xy)
+
+ # Solve for trans1
+ trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
+
+ # Solve for trans2
+
+ # manually reflect the xy data across the Y-axis
+ xyR = xy.copy() # copy so the caller's xy is not mutated before the norm comparison below
+ xyR[:, 0] = -1 * xyR[:, 0]
+
+ trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
+
+ # manually reflect the tform to undo the reflection done on xyR
+ TreflectY = np.array([
+ [-1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]
+ ])
+
+ trans2 = np.dot(trans2r, TreflectY)
+
+ # Figure out if trans1 or trans2 is better
+ xy1 = tformfwd(trans1, uv)
+ norm1 = norm(xy1 - xy)
+
+ xy2 = tformfwd(trans2, uv)
+ norm2 = norm(xy2 - xy)
+
+ if norm1 <= norm2:
+ return trans1, trans1_inv
+ else:
+ trans2_inv = inv(trans2)
+ return trans2, trans2_inv
+
+
+def get_similarity_transform(src_pts, dst_pts, reflective=True):
+ """
+ Function:
+ ----------
+ Find Similarity Transform Matrix 'trans':
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y, 1] = [u, v, 1] * trans
+
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points, each row is a pair of transformed
+ coordinates (x, y)
+ @reflective: True or False
+ if True:
+ use reflective similarity transform
+ else:
+ use non-reflective similarity transform
+
+ Returns:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix from uv to xy
+ trans_inv: 3x3 np.array
+ inverse of trans, transform matrix from xy to uv
+ """
+
+ if reflective:
+ trans, trans_inv = findSimilarity(src_pts, dst_pts)
+ else:
+ trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
+
+ return trans, trans_inv
+
+
+def cvt_tform_mat_for_cv2(trans):
+ """
+ Function:
+ ----------
+ Convert Transform Matrix 'trans' into 'cv2_trans' which could be
+ directly used by cv2.warpAffine():
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y].T = cv_trans * [u, v, 1].T
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix from uv to xy
+
+ Returns:
+ ----------
+ @cv2_trans: 2x3 np.array
+ transform matrix from src_pts to dst_pts, could be directly used
+ for cv2.warpAffine()
+ """
+ cv2_trans = trans[:, 0:2].T
+
+ return cv2_trans
+
+
+def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
+ """
+ Function:
+ ----------
+ Find Similarity Transform Matrix 'cv2_trans' which could be
+ directly used by cv2.warpAffine():
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y].T = cv_trans * [u, v, 1].T
+
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points, each row is a pair of transformed
+ coordinates (x, y)
+ reflective: True or False
+ if True:
+ use reflective similarity transform
+ else:
+ use non-reflective similarity transform
+
+ Returns:
+ ----------
+ @cv2_trans: 2x3 np.array
+ transform matrix from src_pts to dst_pts, could be directly used
+ for cv2.warpAffine()
+ """
+ trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
+ cv2_trans = cvt_tform_mat_for_cv2(trans)
+
+ return cv2_trans
+
+
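+# Example (sketch; `image`, `src_pts` and `dst_pts` are hypothetical): align a crop with
+# OpenCV using the 2x3 matrix returned above.
+#
+# import cv2
+# cv2_trans = get_similarity_transform_for_cv2(src_pts, dst_pts)
+# aligned = cv2.warpAffine(image, cv2_trans, (112, 112))
+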
+if __name__ == '__main__':
+ """
+ u = [0, 6, -2]
+ v = [0, 3, 5]
+ x = [-1, 0, 4]
+ y = [-1, -10, 4]
+
+ # In Matlab, run:
+ #
+ # uv = [u'; v'];
+ # xy = [x'; y'];
+ # tform_sim=cp2tform(uv,xy,'similarity');
+ #
+ # trans = tform_sim.tdata.T
+ # ans =
+ # -0.0764 -1.6190 0
+ # 1.6190 -0.0764 0
+ # -3.2156 0.0290 1.0000
+ # trans_inv = tform_sim.tdata.Tinv
+ # ans =
+ #
+ # -0.0291 0.6163 0
+ # -0.6163 -0.0291 0
+ # -0.0756 1.9826 1.0000
+ # xy_m=tformfwd(tform_sim, u,v)
+ #
+ # xy_m =
+ #
+ # -3.2156 0.0290
+ # 1.1833 -9.9143
+ # 5.0323 2.8853
+ # uv_m=tforminv(tform_sim, x,y)
+ #
+ # uv_m =
+ #
+ # 0.5698 1.3953
+ # 6.0872 2.2733
+ # -2.6570 4.3314
+ """
+ u = [0, 6, -2]
+ v = [0, 3, 5]
+ x = [-1, 0, 4]
+ y = [-1, -10, 4]
+
+ uv = np.array((u, v)).T
+ xy = np.array((x, y)).T
+
+ print('\n--->uv:')
+ print(uv)
+ print('\n--->xy:')
+ print(xy)
+
+ trans, trans_inv = get_similarity_transform(uv, xy)
+
+ print('\n--->trans matrix:')
+ print(trans)
+
+ print('\n--->trans_inv matrix:')
+ print(trans_inv)
+
+ print('\n---> apply transform to uv')
+ print('\nxy_m = uv_augmented * trans')
+ uv_aug = np.hstack((
+ uv, np.ones((uv.shape[0], 1))
+ ))
+ xy_m = np.dot(uv_aug, trans)
+ print(xy_m)
+
+ print('\nxy_m = tformfwd(trans, uv)')
+ xy_m = tformfwd(trans, uv)
+ print(xy_m)
+
+ print('\n---> apply inverse transform to xy')
+ print('\nuv_m = xy_augmented * trans_inv')
+ xy_aug = np.hstack((
+ xy, np.ones((xy.shape[0], 1))
+ ))
+ uv_m = np.dot(xy_aug, trans_inv)
+ print(uv_m)
+
+ print('\nuv_m = tformfwd(trans_inv, xy)')
+ uv_m = tformfwd(trans_inv, xy)
+ print(uv_m)
+
+ uv_m = tforminv(trans, xy)
+ print('\nuv_m = tforminv(trans, xy)')
+ print(uv_m)
diff --git a/utils/parser.py b/utils/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..3830be1212234577d6cd86a293b1285269008537
--- /dev/null
+++ b/utils/parser.py
@@ -0,0 +1,452 @@
+import argparse
+import os
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Train Consistency Encoder.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--pretrained_vae_model_name_or_path",
+ type=str,
+ default=None,
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+
+ # parser.add_argument(
+ # "--instance_data_dir",
+ # type=str,
+ # required=True,
+ # help=("A folder containing the training data. "),
+ # )
+
+ parser.add_argument(
+ "--data_config_path",
+ type=str,
+ required=True,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ required=True,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_train_vis_images",
+ type=int,
+ default=2,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=2,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+
+ parser.add_argument(
+ "--validation_vis_steps",
+ type=int,
+ default=500,
+ help=(
+ "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+
+ parser.add_argument(
+ "--train_vis_steps",
+ type=int,
+ default=500,
+ help=(
+ "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+
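+ # note: argparse's `type=bool` treats any non-empty string as True (e.g. "--vis_lcm False" still
+ # yields True); leave the flag unset to rely on the default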
+ parser.add_argument(
+ "--vis_lcm",
+ type=bool,
+ default=True,
+ help=(
+ "Also log results of LCM inference",
+ ),
+ )
+
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="lora-dreambooth-model",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+
+ parser.add_argument("--save_only_encoder", action="store_true", help="Only save the encoder and not the full accelerator state")
+
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+
+ parser.add_argument("--freeze_encoder_unet", action="store_true", help="Don't train encoder unet")
+ parser.add_argument("--predict_word_embedding", action="store_true", help="Predict word embeddings in addition to KV features")
+ parser.add_argument("--ip_adapter_feature_extractor_path", type=str, help="Path to pre-trained feature extractor for IP-adapter")
+ parser.add_argument("--ip_adapter_model_path", type=str, help="Path to pre-trained IP-adapter.")
+ parser.add_argument("--ip_adapter_tokens", type=int, default=16, help="Number of tokens to use in IP-adapter cross attention mechanism")
+ parser.add_argument("--optimize_adapter", action="store_true", help="Optimize IP-adapter parameters (projector + cross-attention layers)")
+ parser.add_argument("--adapter_attention_scale", type=float, default=1.0, help="Relative strength of the adapter cross attention layers")
+ parser.add_argument("--adapter_lr", type=float, help="Learning rate for the adapter parameters. Defaults to the global LR if not provided")
+
+ parser.add_argument("--noisy_encoder_input", action="store_true", help="Noise the encoder input to the same step as the decoder?")
+
+ # related to CFG:
+ parser.add_argument("--adapter_drop_chance", type=float, default=0.0, help="Chance to drop adapter condition input during training")
+ parser.add_argument("--text_drop_chance", type=float, default=0.0, help="Chance to drop text condition during training")
+ parser.add_argument("--kv_drop_chance", type=float, default=0.0, help="Chance to drop KV condition during training")
+
+
+
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=1024,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+
+ parser.add_argument(
+ "--crops_coords_top_left_h",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+
+ parser.add_argument(
+ "--crops_coords_top_left_w",
+ type=int,
+ default=0,
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+ )
+
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=5,
+ help=("Max number of checkpoints to store."),
+ )
+
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+
+ parser.add_argument("--max_timesteps_for_x0_loss", type=int, default=1001)
+
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+
+ parser.add_argument(
+ "--snr_gamma",
+ type=float,
+ default=None,
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+ "More details here: https://arxiv.org/abs/2303.09556.",
+ )
+
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="wandb",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+ ' `"wandb"` (the default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ parser.add_argument(
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+ )
+
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+
+ parser.add_argument(
+ "--pretrained_lcm_lora_path",
+ type=str,
+ default="latent-consistency/lcm-lora-sdxl",
+ help=("Path for lcm lora pretrained"),
+ )
+
+ parser.add_argument(
+ "--losses_config_path",
+ type=str,
+ required=True,
+ help=("A yaml file containing losses to use and their weights."),
+ )
+
+ parser.add_argument(
+ "--lcm_every_k_steps",
+ type=int,
+ default=-1,
+ help="How often to run lcm. If -1, lcm is not run."
+ )
+
+ parser.add_argument(
+ "--lcm_batch_size",
+ type=int,
+ default=1,
+ help="Batch size for lcm."
+ )
+ parser.add_argument(
+ "--lcm_max_timestep",
+ type=int,
+ default=1000,
+ help="Max timestep to use with LCM."
+ )
+
+ parser.add_argument(
+ "--lcm_sample_scale_every_k_steps",
+ type=int,
+ default=-1,
+ help="How often to change lcm scale. If -1, scale is fixed at 1."
+ )
+
+ parser.add_argument(
+ "--lcm_min_scale",
+ type=float,
+ default=0.1,
+ help="When sampling lcm scale, the minimum scale to use."
+ )
+
+ parser.add_argument(
+ "--scale_lcm_by_max_step",
+ action="store_true",
+ help="scale LCM lora alpha linearly by the maximal timestep sampled that iteration"
+ )
+
+ parser.add_argument(
+ "--lcm_sample_full_lcm_prob",
+ type=float,
+ default=0.2,
+ help="When sampling lcm scale, the probability of using full lcm (scale of 1)."
+ )
+
+ parser.add_argument(
+ "--run_on_cpu",
+ action="store_true",
+ help="whether to run on cpu or not"
+ )
+
+ parser.add_argument(
+ "--experiment_name",
+ type=str,
+ help=("A short description of the experiment to add to the wand run log. "),
+ )
+ parser.add_argument("--encoder_lora_rank", type=int, default=0, help="Rank of Lora in unet encoder. 0 means no lora")
+
+ parser.add_argument("--kvcopy_lora_rank", type=int, default=0, help="Rank of lora in the kvcopy modules. 0 means no lora")
+
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ args.optimizer = "AdamW"
+
+ return args
\ No newline at end of file
diff --git a/utils/text_utils.py b/utils/text_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d1326275f969ab5117f31b508c2e30b1438f893
--- /dev/null
+++ b/utils/text_utils.py
@@ -0,0 +1,76 @@
+import torch
+
+def tokenize_prompt(tokenizer, prompt):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
+def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
+ prompt_embeds_list = []
+
+ for i, text_encoder in enumerate(text_encoders):
+ if tokenizers is not None:
+ tokenizer = tokenizers[i]
+ text_input_ids = tokenize_prompt(tokenizer, prompt)
+ else:
+ assert text_input_ids_list is not None
+ text_input_ids = text_input_ids_list[i]
+
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # we are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
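+# Example (sketch; the encoder/tokenizer objects are hypothetical): SDXL uses two text
+# encoders, so both lists have two entries and their embeddings are concatenated.
+#
+# prompt_embeds, pooled = encode_prompt(
+#     [text_encoder_one, text_encoder_two],
+#     [tokenizer_one, tokenizer_two],
+#     "a photo of a sculpture",
+# )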
+
+def add_tokens(tokenizers, tokens, text_encoders):
+ new_token_indices = {}
+ for idx, tokenizer in enumerate(tokenizers):
+ for token in tokens:
+ num_added_tokens = tokenizer.add_tokens(token)
+ if num_added_tokens == 0:
+ raise ValueError(
+ f"The tokenizer already contains the token {token}. Please pass a different"
+ " `placeholder_token` that is not already in the tokenizer."
+ )
+
+ new_token_indices[f"{idx}_{token}"] = num_added_tokens
+ # resize embedding layers to avoid crash. We will never actually use these.
+ text_encoders[idx].resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128)
+
+ return new_token_indices
+
+
+def patch_embedding_forward(embedding_layer, new_tokens, new_embeddings):
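+ """
+ Monkey-patch `embedding_layer.forward` so that positions whose ids match `new_tokens`
+ are embedded with `new_embeddings` instead of the layer's stored rows.
+ """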
+
+ def new_forward(input):
+ embedded_text = torch.nn.functional.embedding(
+ input, embedding_layer.weight, embedding_layer.padding_idx, embedding_layer.max_norm,
+ embedding_layer.norm_type, embedding_layer.scale_grad_by_freq, embedding_layer.sparse)
+
+ replace_indices = (input == new_tokens)
+
+ if torch.count_nonzero(replace_indices) > 0:
+ embedded_text[replace_indices] = new_embeddings
+
+ return embedded_text
+
+ embedding_layer.forward = new_forward
\ No newline at end of file
diff --git a/utils/train_utils.py b/utils/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..02d6cd734dbd1cbe64b75e19f61bb5ddd782ab76
--- /dev/null
+++ b/utils/train_utils.py
@@ -0,0 +1,360 @@
+import argparse
+import contextlib
+import time
+import gc
+import logging
+import math
+import os
+import random
+import jsonlines
+import functools
+import shutil
+import pyrallis
+import itertools
+from pathlib import Path
+from collections import namedtuple, OrderedDict
+
+import accelerate
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from datasets import load_dataset
+from packaging import version
+from PIL import Image
+from losses.losses import *
+from torchvision import transforms
+from torchvision.transforms.functional import crop
+from tqdm.auto import tqdm
+
+# module-level logger used by the dataset helpers below
+logger = get_logger(__name__)
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ from transformers import PretrainedConfig
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "CLIPTextModelWithProjection":
+ from transformers import CLIPTextModelWithProjection
+
+ return CLIPTextModelWithProjection
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+def get_train_dataset(dataset_name, dataset_dir, args, accelerator):
+ # Get the datasets: you can either provide your own training and evaluation files (see below)
+ # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
+
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+ # download the dataset.
+ dataset = load_dataset(
+ dataset_name,
+ data_dir=dataset_dir,
+ cache_dir=os.path.join(dataset_dir, ".cache"),
+ num_proc=4,
+ split="train",
+ )
+
+ # Preprocessing the datasets.
+ # We need to tokenize inputs and targets.
+ column_names = dataset.column_names
+
+ # 6. Get the column names for input/target.
+ if args.image_column is None:
+ args.image_column = column_names[0]
+ logger.info(f"image column defaulting to {column_names[0]}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ logger.warning(f"dataset {dataset_name} has no column {image_column}")
+
+ if args.caption_column is None:
+ args.caption_column = column_names[1]
+ logger.info(f"caption column defaulting to {column_names[1]}")
+ else:
+ caption_column = args.caption_column
+ if caption_column not in column_names:
+ logger.warning(f"dataset {dataset_name} has no column {caption_column}")
+
+ if args.conditioning_image_column is None:
+ args.conditioning_image_column = column_names[2]
+ logger.info(f"conditioning image column defaulting to {column_names[2]}")
+ else:
+ conditioning_image_column = args.conditioning_image_column
+ if conditioning_image_column not in column_names:
+ logger.warning(f"dataset {dataset_name} has no column {conditioning_image_column}")
+
+ with accelerator.main_process_first():
+ train_dataset = dataset.shuffle(seed=args.seed)
+ if args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(args.max_train_samples))
+ return train_dataset
+
+def prepare_train_dataset(dataset, accelerator, deg_pipeline, centralize=False):
+
+ # Data augmentations.
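+ # note: the flip/rotation choices are sampled once here, when the transform pipeline is built,
+ # so every image produced by preprocess_train shares the same augmentation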
+ hflip = deg_pipeline.augment_opt['use_hflip'] and random.random() < 0.5
+ vflip = deg_pipeline.augment_opt['use_rot'] and random.random() < 0.5
+ rot90 = deg_pipeline.augment_opt['use_rot'] and random.random() < 0.5
+ augment_transforms = []
+ if hflip:
+ augment_transforms.append(transforms.RandomHorizontalFlip(p=1.0))
+ if vflip:
+ augment_transforms.append(transforms.RandomVerticalFlip(p=1.0))
+ if rot90:
+ # FIXME
+ augment_transforms.append(transforms.RandomRotation(degrees=(90,90)))
+ torch_transforms=[transforms.ToTensor()]
+ if centralize:
+ # to [-1, 1]
+ torch_transforms.append(transforms.Normalize([0.5], [0.5]))
+
+ training_size = deg_pipeline.degrade_opt['gt_size']
+ image_transforms = transforms.Compose(augment_transforms)
+ train_transforms = transforms.Compose(torch_transforms)
+ train_resize = transforms.Resize(training_size, interpolation=transforms.InterpolationMode.BILINEAR)
+ train_crop = transforms.RandomCrop(training_size)
+
+ def preprocess_train(examples):
+ raw_images = []
+ for img_data in examples[args.image_column]:
+ raw_images.append(Image.open(img_data).convert("RGB"))
+
+ # Image stack.
+ images = []
+ original_sizes = []
+ crop_top_lefts = []
+ # Degradation kernels stack.
+ kernel = []
+ kernel2 = []
+ sinc_kernel = []
+
+ for raw_image in raw_images:
+ raw_image = image_transforms(raw_image)
+ original_sizes.append((raw_image.height, raw_image.width))
+
+ # Resize smaller edge.
+ raw_image = train_resize(raw_image)
+ # Crop to training size.
+ y1, x1, h, w = train_crop.get_params(raw_image, (training_size, training_size))
+ raw_image = crop(raw_image, y1, x1, h, w)
+ crop_top_left = (y1, x1)
+ crop_top_lefts.append(crop_top_left)
+ image = train_transforms(raw_image)
+
+ images.append(image)
+ k, k2, sk = deg_pipeline.get_kernel()
+ kernel.append(k)
+ kernel2.append(k2)
+ sinc_kernel.append(sk)
+
+ examples["images"] = images
+ examples["original_sizes"] = original_sizes
+ examples["crop_top_lefts"] = crop_top_lefts
+ examples["kernel"] = kernel
+ examples["kernel2"] = kernel2
+ examples["sinc_kernel"] = sinc_kernel
+
+ return examples
+
+ with accelerator.main_process_first():
+ dataset = dataset.with_transform(preprocess_train)
+
+ return dataset
+
+def collate_fn(examples):
+ images = torch.stack([example["images"] for example in examples])
+ images = images.to(memory_format=torch.contiguous_format).float()
+ kernel = torch.stack([example["kernel"] for example in examples])
+ kernel = kernel.to(memory_format=torch.contiguous_format).float()
+ kernel2 = torch.stack([example["kernel2"] for example in examples])
+ kernel2 = kernel2.to(memory_format=torch.contiguous_format).float()
+ sinc_kernel = torch.stack([example["sinc_kernel"] for example in examples])
+ sinc_kernel = sinc_kernel.to(memory_format=torch.contiguous_format).float()
+ original_sizes = [example["original_sizes"] for example in examples]
+ crop_top_lefts = [example["crop_top_lefts"] for example in examples]
+
+ prompts = []
+ for example in examples:
+ prompts.append(example[args.caption_column] if args.caption_column in example else "")
+
+ return {
+ "images": images,
+ "text": prompts,
+ "kernel": kernel,
+ "kernel2": kernel2,
+ "sinc_kernel": sinc_kernel,
+ "original_sizes": original_sizes,
+ "crop_top_lefts": crop_top_lefts,
+ }
+
+def encode_prompt(prompt_batch, text_encoders, tokenizers, is_train=True):
+ prompt_embeds_list = []
+
+ captions = []
+ for caption in prompt_batch:
+ if isinstance(caption, str):
+ captions.append(caption)
+ elif isinstance(caption, (list, np.ndarray)):
+ # take a random caption if there are multiple
+ captions.append(random.choice(caption) if is_train else caption[0])
+
+ with torch.no_grad():
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+ text_inputs = tokenizer(
+ captions,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ prompt_embeds = text_encoder(
+ text_input_ids.to(text_encoder.device),
+ output_hidden_states=True,
+ )
+
+ # we are only interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
+ return prompt_embeds, pooled_prompt_embeds
+
+def importance_sampling_fn(t, max_t, alpha):
+ """Importance Sampling Function f(t)"""
+ return 1 / max_t * (1 - alpha * np.cos(np.pi * t / max_t))
+
+def extract_into_tensor(a, t, x_shape):
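+ """Gather per-timestep coefficients a[t] and reshape them to broadcast over a tensor of shape x_shape."""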
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+def tensor_to_pil(images):
+ """
+ Convert image tensor or a batch of image tensors to PIL image(s).
+ """
+ images = (images + 1) / 2
+ images_np = images.detach().cpu().numpy()
+ if images_np.ndim == 4:
+ images_np = np.transpose(images_np, (0, 2, 3, 1))
+ elif images_np.ndim == 3:
+ images_np = np.transpose(images_np, (1, 2, 0))
+ images_np = images_np[None, ...]
+ images_np = (images_np * 255).round().astype("uint8")
+ if images_np.shape[-1] == 1:
+ # special case for grayscale (single channel) images
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images_np]
+ else:
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images_np]
+
+ return pil_images
+
+def save_np_to_image(img_np, save_dir):
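+ """Save a (1, C, H, W) float array in [0, 1] as an 8-bit image at the path `save_dir`."""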
+ img_np = np.transpose(img_np, (0, 2, 3, 1))
+ img_np = (img_np * 255).astype(np.uint8)
+ img_np = Image.fromarray(img_np[0])
+ img_np.save(save_dir)
+
+
+def seperate_SFT_params_from_unet(unet):
+ params = []
+ non_params = []
+ for name, param in unet.named_parameters():
+ if "SFT" in name:
+ params.append(param)
+ else:
+ non_params.append(param)
+ return params, non_params
+
+
+def seperate_lora_params_from_unet(unet):
+ keys = []
+ frozen_keys = []
+ for name, param in unet.named_parameters():
+ if "lora" in name:
+ keys.append(param)
+ else:
+ frozen_keys.append(param)
+ return keys, frozen_keys
+
+
+def seperate_ip_params_from_unet(unet):
+ ip_params = []
+ non_ip_params = []
+ for name, param in unet.named_parameters():
+ if "encoder_hid_proj." in name or "_ip." in name:
+ ip_params.append(param)
+ elif "attn" in name and "processor" in name:
+ if "ip" in name or "ln" in name:
+ ip_params.append(param)
+ else:
+ non_ip_params.append(param)
+ return ip_params, non_ip_params
+
+
+def seperate_ref_params_from_unet(unet):
+ ip_params = []
+ non_ip_params = []
+ for name, param in unet.named_parameters():
+ if "encoder_hid_proj." in name or "_ip." in name:
+ ip_params.append(param)
+ elif "attn" in name and "processor" in name:
+ if "ip" in name or "ln" in name:
+ ip_params.append(param)
+ elif "extract" in name:
+ ip_params.append(param)
+ else:
+ non_ip_params.append(param)
+ return ip_params, non_ip_params
+
+
+def seperate_ip_modules_from_unet(unet):
+ ip_modules = []
+ non_ip_modules = []
+ for name, module in unet.named_modules():
+ if "encoder_hid_proj" in name or "attn2.processor" in name:
+ ip_modules.append(module)
+ else:
+ non_ip_modules.append(module)
+ return ip_modules, non_ip_modules
+
+
+def seperate_SFT_keys_from_unet(unet):
+ keys = []
+ non_keys = []
+ for name, param in unet.named_parameters():
+ if "SFT" in name:
+ keys.append(name)
+ else:
+ non_keys.append(name)
+ return keys, non_keys
+
+
+def seperate_ip_keys_from_unet(unet):
+ keys = []
+ non_keys = []
+ for name, param in unet.named_parameters():
+ if "encoder_hid_proj." in name or "_ip." in name:
+ keys.append(name)
+ elif "attn" in name and "processor" in name:
+ if "ip" in name or "ln" in name:
+ keys.append(name)
+ else:
+ non_keys.append(name)
+ return keys, non_keys
\ No newline at end of file
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e85c7a43b46fd305c643e1c1b387e7490e5dcc5
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,51 @@
+import torch
+import numpy as np
+from einops import rearrange
+from kornia.geometry.transform.crop2d import warp_affine
+
+from utils.matlab_cp2tform import get_similarity_transform_for_cv2
+from torchvision.transforms import Pad
+
+REFERNCE_FACIAL_POINTS_RELATIVE = np.array([[38.29459953, 51.69630051],
+ [72.53179932, 51.50139999],
+ [56.02519989, 71.73660278],
+ [41.54930115, 92.3655014],
+ [70.72990036, 92.20410156]
+ ]) / 112 # original points are defined for a 112x96 crop; 8 was added to the x axis to make it 112x112, then normalized by 112
+
+
+@torch.no_grad()
+def detect_face(images: torch.Tensor, mtcnn: torch.nn.Module) -> torch.Tensor:
+ """
+ Detect faces in the images using MTCNN. If no face is detected, use the whole image.
+ """
+ images = rearrange(images, "b c h w -> b h w c")
+ if images.dtype != torch.uint8:
+ images = ((images * 0.5 + 0.5) * 255).type(torch.uint8) # Unnormalize
+
+ _, _, landmarks = mtcnn(images, landmarks=True)
+
+ return landmarks
+
+
+def extract_faces_and_landmarks(images: torch.Tensor, output_size=112, mtcnn: torch.nn.Module = None, refernce_points=REFERNCE_FACIAL_POINTS_RELATIVE):
+ """
+ Detect faces in the images and crop them (in a differentiable way) to `output_size` x `output_size` using MTCNN landmarks.
+ """
+ images = Pad(200)(images)
+ landmarks_batched = detect_face(images, mtcnn=mtcnn)
+ affine_transformations = []
+ invalid_indices = []
+ for i, landmarks in enumerate(landmarks_batched):
+ if landmarks is None:
+ invalid_indices.append(i)
+ affine_transformations.append(np.eye(2, 3).astype(np.float32))
+ else:
+ affine_transformations.append(get_similarity_transform_for_cv2(landmarks[0].astype(np.float32),
+ refernce_points.astype(np.float32) * output_size))
+ affine_transformations = torch.from_numpy(np.stack(affine_transformations).astype(np.float32)).to(device=images.device, dtype=torch.float32)
+
+ invalid_indices = torch.tensor(invalid_indices).to(device=images.device)
+
+ fp_images = images.to(torch.float32)
+ return warp_affine(fp_images, affine_transformations, dsize=(output_size, output_size)).to(dtype=images.dtype), invalid_indices
\ No newline at end of file
diff --git a/utils/vis_utils.py b/utils/vis_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e25a788a9010d002d622104bf6b7a3e05c375749
--- /dev/null
+++ b/utils/vis_utils.py
@@ -0,0 +1,58 @@
+import textwrap
+from typing import List, Tuple, Optional
+
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+
+LINE_WIDTH = 20
+
+
+def add_text_to_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0),
+ min_lines: Optional[int] = None, add_below: bool = True):
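+ """Render `text` in a white band added below (or above) `image`; returns the combined image as a numpy array."""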
+ lines = textwrap.wrap(text, width=LINE_WIDTH)
+ if min_lines is not None and len(lines) < min_lines:
+ if add_below:
+ lines += [''] * (min_lines - len(lines))
+ else:
+ lines = [''] * (min_lines - len(lines)) + lines
+ h, w, c = image.shape
+ offset = int(h * .12)
+ img = np.ones((h + offset * len(lines), w, c), dtype=np.uint8) * 255
+ font_size = int(offset * .8)
+
+ try:
+ font = ImageFont.truetype("assets/OpenSans-Regular.ttf", font_size)
+ textsize = font.getbbox(text)
+ y_offset = (offset - textsize[3]) // 2
+ except Exception:
+ font = ImageFont.load_default()
+ y_offset = offset // 2
+
+ if add_below:
+ img[:h] = image
+ else:
+ img[-h:] = image
+ img = Image.fromarray(img)
+ draw = ImageDraw.Draw(img)
+ for i, line in enumerate(lines):
+ line_size = font.getbbox(line)
+ text_x = (w - line_size[2]) // 2
+ if add_below:
+ draw.text((text_x, h + y_offset + offset * i), line, font=font, fill=text_color)
+ else:
+ draw.text((text_x, 0 + y_offset + offset * i), line, font=font, fill=text_color)
+ return np.array(img)
+
+
+def create_table_plot(titles: List[str], images: List[Image.Image], captions: List[str]) -> Image.Image:
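+ """Stack each image with its title above and caption below, then concatenate the results horizontally."""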
+ title_max_lines = np.max([len(textwrap.wrap(text, width=LINE_WIDTH)) for text in titles])
+ caption_max_lines = np.max([len(textwrap.wrap(text, width=LINE_WIDTH)) for text in captions])
+ out_images = []
+ for i in range(len(images)):
+ im = np.array(images[i])
+ im = add_text_to_image(im, titles[i], add_below=False, min_lines=title_max_lines)
+ im = add_text_to_image(im, captions[i], add_below=True, min_lines=caption_max_lines)
+ out_images.append(im)
+ image = Image.fromarray(np.concatenate(out_images, axis=1))
+ return image