fffiloni committed on
Commit
0305a63
1 Parent(s): 1924f8c

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the full changeset.
Files changed (50)
  1. .gitattributes +2 -0
  2. LICENSE +201 -0
  3. ORIGINAL_README.md +191 -0
  4. assets/.DS_Store +0 -0
  5. assets/Konan.png +0 -0
  6. assets/Naruto.png +0 -0
  7. assets/cottage.png +0 -0
  8. assets/dog.png +0 -0
  9. assets/lady.png +0 -0
  10. assets/man.png +0 -0
  11. assets/panda.png +0 -0
  12. assets/sculpture.png +3 -0
  13. assets/teaser_figure.png +3 -0
  14. config_files/IR_dataset.yaml +9 -0
  15. config_files/losses.yaml +19 -0
  16. config_files/val_dataset.yaml +7 -0
  17. data/data_config.py +14 -0
  18. data/dataset.py +202 -0
  19. docs/.DS_Store +0 -0
  20. docs/static/.DS_Store +0 -0
  21. environment.yaml +37 -0
  22. gradio_demo/app.py +250 -0
  23. infer.py +381 -0
  24. infer.sh +6 -0
  25. losses/loss_config.py +15 -0
  26. losses/losses.py +465 -0
  27. module/aggregator.py +983 -0
  28. module/attention.py +259 -0
  29. module/diffusers_vae/autoencoder_kl.py +489 -0
  30. module/diffusers_vae/vae.py +985 -0
  31. module/ip_adapter/attention_processor.py +1467 -0
  32. module/ip_adapter/ip_adapter.py +236 -0
  33. module/ip_adapter/resampler.py +158 -0
  34. module/ip_adapter/utils.py +248 -0
  35. module/min_sdxl.py +915 -0
  36. module/unet/unet_2d_ZeroSFT.py +1397 -0
  37. module/unet/unet_2d_ZeroSFT_blocks.py +0 -0
  38. pipelines/sdxl_instantir.py +1740 -0
  39. pipelines/stage1_sdxl_pipeline.py +1283 -0
  40. requirements.txt +14 -0
  41. schedulers/lcm_single_step_scheduler.py +537 -0
  42. train_previewer_lora.py +1712 -0
  43. train_previewer_lora.sh +24 -0
  44. train_stage1_adapter.py +1259 -0
  45. train_stage1_adapter.sh +17 -0
  46. train_stage2_aggregator.py +1698 -0
  47. train_stage2_aggregator.sh +24 -0
  48. utils/degradation_pipeline.py +353 -0
  49. utils/matlab_cp2tform.py +350 -0
  50. utils/parser.py +452 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/sculpture.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/teaser_figure.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
ORIGINAL_README.md ADDED
@@ -0,0 +1,191 @@
1
+ <div align="center">
2
+ <h1>InstantIR: Blind Image Restoration with</br>Instant Generative Reference</h1>
3
+
4
+ [**Jen-Yuan Huang**](https://jy-joy.github.io)<sup>1&nbsp;2</sup>, [**Haofan Wang**](https://haofanwang.github.io/)<sup>2</sup>, [**Qixun Wang**](https://github.com/wangqixun)<sup>2</sup>, [**Xu Bai**](https://huggingface.co/baymin0220)<sup>2</sup>, Hao Ai<sup>2</sup>, Peng Xing<sup>2</sup>, [**Jen-Tse Huang**](https://penguinnnnn.github.io)<sup>3</sup> <br>
5
+
6
+ <sup>1</sup>Peking University · <sup>2</sup>InstantX Team · <sup>3</sup>The Chinese University of Hong Kong
7
+
8
+ <!-- <sup>*</sup>corresponding authors -->
9
+
10
+ <a href='https://arxiv.org/abs/2410.06551'><img src='https://img.shields.io/badge/arXiv-2410.06551-b31b1b.svg'>
11
+ <a href='https://jy-joy.github.io/InstantIR/'><img src='https://img.shields.io/badge/Project-Website-green'></a>
12
+ <a href='https://huggingface.co/InstantX/InstantIR'><img src='https://img.shields.io/static/v1?label=Model&message=Huggingface&color=orange'></a>
13
+ <!-- [![GitHub](https://img.shields.io/github/stars/InstantID/InstantID?style=social)](https://github.com/InstantID/InstantID) -->
14
+
15
+ <!-- <a href='https://huggingface.co/spaces/InstantX/InstantID'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a>
16
+ [![ModelScope](https://img.shields.io/badge/ModelScope-Studios-blue)](https://modelscope.cn/studios/instantx/InstantID/summary)
17
+ [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/InstantX/InstantID) -->
18
+
19
+ </div>
20
+
21
+ **InstantIR** is a novel single-image restoration model designed to resurrect your damaged images, delivering extreme-quality yet realistic details. You can further boost **InstantIR** performance with additional text prompts, and even achieve customized editing!
22
+
23
+
24
+ <!-- >**Abstract**: <br>
25
+ > Handling test-time unknown degradation is the major challenge in Blind Image Restoration (BIR), necessitating high model generalization. An effective strategy is to incorporate prior knowledge, either from human input or generative model. In this paper, we introduce Instant-reference Image Restoration (InstantIR), a novel diffusion-based BIR method which dynamically adjusts generation condition during inference. We first extract a compact representation of the input via a pre-trained vision encoder. At each generation step, this representation is used to decode current diffusion latent and instantiate it in the generative prior. The degraded image is then encoded with this reference, providing robust generation condition. We observe the variance of generative references fluctuate with degradation intensity, which we further leverage as an indicator for developing a sampling algorithm adaptive to input quality. Extensive experiments demonstrate InstantIR achieves state-of-the-art performance and offering outstanding visual quality. Through modulating generative references with textual description, InstantIR can restore extreme degradation and additionally feature creative restoration. -->
26
+
27
+ <img src='assets/teaser_figure.png'>
28
+
29
+ ## 📢 News
30
+ - **11/03/2024** 🔥 We provide a Gradio launching script for InstantIR; you can now deploy it on your local machine!
31
+ - **11/02/2024** 🔥 InstantIR is now compatible with 🧨 `diffusers`; you can utilize features from this fascinating package!
32
+ - **10/15/2024** 🔥 Code and model released!
33
+
34
+ ## 📝 TODOs:
35
+ - [ ] Launch online demo
36
+ - [x] Remove dependency on local `diffusers`
37
+ - [x] Gradio launching script
38
+
39
+ ## ✨ Usage
40
+ <!-- ### Online Demo
41
+ We provide a Gradio Demo on 🤗, click the button below and have fun with InstantIR! -->
42
+
43
+ ### Quick start
44
+ #### 1. Clone this repo and set up the environment
45
+ ```sh
46
+ git clone https://github.com/JY-Joy/InstantIR.git
47
+ cd InstantIR
48
+ conda create -n instantir python=3.9 -y
49
+ conda activate instantir
50
+ pip install -r requirements.txt
51
+ ```
52
+
53
+ #### 2. Download pre-trained models
54
+
55
+ InstantIR is built on SDXL and DINOv2. You can download them either directly from 🤗 Hugging Face or via the Python package.
56
+
57
+ | 🤗 link | Python command
58
+ | :--- | :----------
59
+ |[SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) | `hf_hub_download(repo_id="stabilityai/stable-diffusion-xl-base-1.0")`
60
+ |[facebook/dinov2-large](https://huggingface.co/facebook/dinov2-large) | `hf_hub_download(repo_id="facebook/dinov2-large")`
61
+ |[InstantX/InstantIR](https://huggingface.co/InstantX/InstantIR) | `hf_hub_download(repo_id="InstantX/InstantIR")`
62
+
63
+ Note: Make sure to import the package first with `from huggingface_hub import hf_hub_download` if you are using a Python script.
64
+
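+ If you prefer to fetch whole repositories in one call, `huggingface_hub.snapshot_download` downloads a full repository snapshot and returns its local path (`hf_hub_download` retrieves individual files). Below is a minimal sketch; the variable names are only placeholders for the paths used by the scripts later on:
+
+ ```py
+ from huggingface_hub import snapshot_download
+
+ # each call downloads (or reuses) a cached snapshot and returns its local path
+ sdxl_path = snapshot_download(repo_id="stabilityai/stable-diffusion-xl-base-1.0")
+ dinov2_path = snapshot_download(repo_id="facebook/dinov2-large")
+ instantir_path = snapshot_download(repo_id="InstantX/InstantIR")
+ ```
+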
65
+ #### 3. Inference
66
+
67
+ You can run InstantIR inference using `infer.sh` with the following arguments specified.
68
+
69
+ ```sh
70
+ infer.sh \
71
+ --sdxl_path <path_to_SDXL> \
72
+ --vision_encoder_path <path_to_DINOv2> \
73
+ --instantir_path <path_to_InstantIR> \
74
+ --test_path <path_to_input> \
75
+ --out_path <path_to_output>
76
+ ```
77
+
78
+ See `infer.py` for more config options.
79
+
80
+ #### 4. Usage tips
81
+
82
+ InstantIR is powerful, but with your help it can do better. InstantIR's flexible pipeline makes it tunable to a large extent. Here are some tips we found particularly useful for various cases you may encounter:
83
+ - **Over-smoothing**: reduce `--cfg` to 3.0~5.0. Higher CFG scales can sometimes cause rigid lines or a lack of details.
84
+ - **Low fidelity**: set `--preview_start` to 0.1~0.4 to preserve fidelity from the input. The previewer can yield misleading references when the input latent is too noisy; in such cases, we suggest disabling the previewer at early timesteps.
85
+ - **Local distortions**: set `--creative_start` to 0.6~0.8. This will let InstantIR render freely in the late diffusion process, where the high-frequency details are generated. A smaller `--creative_start` leaves more room for creative restoration, but will diminish fidelity.
86
+ - **Faster inference**: a higher `--preview_start` and a lower `--creative_start` can both reduce computational cost and accelerate InstantIR inference.
87
+
88
+ > [!CAUTION]
89
+ > These features are training-free and thus experimental. If you would like to try them, we suggest tuning these parameters case by case, for example as in the sketch below.
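+
+ As an example, a sketch that combines the first three tips in a single run by calling `infer.py` directly (the placeholder paths match the `infer.sh` example above, and the flag values simply sit in the middle of the suggested ranges; see `infer.py` for the full argument list):
+
+ ```sh
+ python infer.py \
+     --sdxl_path <path_to_SDXL> \
+     --vision_encoder_path <path_to_DINOv2> \
+     --instantir_path <path_to_InstantIR> \
+     --test_path <path_to_input> \
+     --out_path <path_to_output> \
+     --cfg 4.0 \
+     --preview_start 0.2 \
+     --creative_start 0.7
+ ```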
90
+
91
+ ### Use InstantIR with diffusers 🧨
92
+
93
+ InstantIR is fully compatible with `diffusers` and supports all the powerful features of this package. You can load InstantIR directly with the following `diffusers` snippet:
94
+
95
+ ```py
96
+ # !pip install diffusers opencv-python transformers accelerate
97
+ import torch
98
+ from PIL import Image
99
+
100
+ from diffusers import DDPMScheduler
101
+ from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
102
+
103
+ from module.ip_adapter.utils import load_adapter_to_pipe
104
+ from pipelines.sdxl_instantir import InstantIRPipeline
105
+
106
+ # suppose you have InstantIR weights under ./models
107
+ instantir_path = f'./models'
108
+
109
+ # load pretrained models
110
+ pipe = InstantIRPipeline.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', torch_dtype=torch.float16)
111
+
112
+ # load adapter
113
+ load_adapter_to_pipe(
114
+ pipe,
115
+ f"{instantir_path}/adapter.pt",
116
+ image_encoder_or_path = 'facebook/dinov2-large',
117
+ )
118
+
119
+ # load previewer lora
120
+ pipe.prepare_previewers(instantir_path)
121
+ pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
122
+ lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
123
+
124
+ # load aggregator weights
125
+ pretrained_state_dict = torch.load(f"{instantir_path}/aggregator.pt")
126
+ pipe.aggregator.load_state_dict(pretrained_state_dict)
127
+
128
+ # send to GPU and fp16
129
+ pipe.to(device='cuda', dtype=torch.float16)
130
+ pipe.aggregator.to(device='cuda', dtype=torch.float16)
131
+ ```
132
+
133
+ Then, you just need to call `pipe` and InstantIR will handle your image!
134
+
135
+ ```py
136
+ # load a broken image
137
+ low_quality_image = Image.open('./assets/sculpture.png').convert("RGB")
138
+
139
+ # InstantIR restoration
140
+ image = pipe(
141
+ image=low_quality_image,
142
+ previewer_scheduler=lcm_scheduler,
143
+ ).images[0]
144
+ ```
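+
+ The returned `image` is a regular PIL image, so you can save it directly. If you also want the intermediate restoration previews, the Gradio demo in `gradio_demo/app.py` calls the pipeline with `return_dict=False` and `save_preview_row=True` and reads the previews from the second element of the output; a minimal sketch of the same pattern:
+
+ ```py
+ image.save("restored.png")
+
+ # optionally collect the previewer outputs as well, as gradio_demo/app.py does
+ out = pipe(
+     image=low_quality_image,
+     previewer_scheduler=lcm_scheduler,
+     return_dict=False,
+     save_preview_row=True,
+ )
+ restored, preview_row = out[0][0], out[1]
+ ```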
145
+
146
+ ### Deploy local gradio demo
147
+
148
+ We provide a Python script to launch a local Gradio demo of InstantIR, with basic and some advanced features implemented. Start by running the following command in your terminal:
149
+
150
+ ```sh
151
+ INSTANTIR_PATH=<path_to_InstantIR> python gradio_demo/app.py
152
+ ```
153
+
154
+ Then, visit your local demo via your browser at `http://localhost:7860`.
155
+
156
+
157
+ ## ⚙️ Training
158
+
159
+ ### Prepare data
160
+
161
+ InstantIR is trained on [DIV2K](https://www.kaggle.com/datasets/joe1995/div2k-dataset), [Flickr2K](https://www.kaggle.com/datasets/daehoyang/flickr2k), [LSDIR](https://data.vision.ee.ethz.ch/yawli/index.html) and [FFHQ](https://www.kaggle.com/datasets/rahulbhalley/ffhq-1024x1024). We adopt dataset weighting to balance the distribution; you can configure the weights in ```config_files/IR_dataset.yaml```. Download these training sets and put them under the same directory, which will be used in the following training configurations.
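+
+ For reference, `config_files/IR_dataset.yaml` as shipped in this commit pairs each dataset folder with a sampling weight:
+
+ ```yaml
+ datasets:
+   - dataset_folder: 'ffhq'
+     dataset_weight: 0.1
+   - dataset_folder: 'DIV2K'
+     dataset_weight: 0.3
+   - dataset_folder: 'LSDIR'
+     dataset_weight: 0.3
+   - dataset_folder: 'Flickr2K'
+     dataset_weight: 0.1
+ ```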
162
+
163
+ ### Two-stage training
164
+ As described in our paper, the training of InstantIR is conducted in two stages. We provide corresponding `.sh` training scripts for each stage. Make sure the following arguments are adapted to your own use case:
165
+
166
+ | Argument | Value
167
+ | :--- | :----------
168
+ | `--pretrained_model_name_or_path` | path to your SDXL folder
169
+ | `--feature_extractor_path` | path to your DINOv2 folder
170
+ | `--train_data_dir` | your training data directory
171
+ | `--output_dir` | path to save model weights
172
+ | `--logging_dir` | path to save logs
173
+ | `<num_of_gpus>` | number of available GPUs
174
+
175
+ Other training hyperparameters we used in our experiments are provided in the corresponding `.sh` scripts. You can tune them according to your own needs.
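+
+ As a purely illustrative sketch of how the table's arguments plug in (assuming the stage scripts wrap the usual `accelerate launch` pattern; consult `train_stage1_adapter.sh` for the real invocation and hyperparameters):
+
+ ```sh
+ # hypothetical sketch: the real command and hyperparameters live in train_stage1_adapter.sh
+ accelerate launch --num_processes <num_of_gpus> train_stage1_adapter.py \
+     --pretrained_model_name_or_path <path_to_SDXL> \
+     --feature_extractor_path <path_to_DINOv2> \
+     --train_data_dir <path_to_training_data> \
+     --output_dir <path_to_output_weights> \
+     --logging_dir <path_to_logs>
+ ```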
176
+
177
+ ## 👏 Acknowledgment
178
+ Our work is sponsored by [HuggingFace](https://huggingface.co) and [fal.ai](https://fal.ai).
179
+
180
+ ## 🎓 Citation
181
+
182
+ If InstantIR is helpful to your work, please cite our paper via:
183
+
184
+ ```
185
+ @article{huang2024instantir,
186
+ title={InstantIR: Blind Image Restoration with Instant Generative Reference},
187
+ author={Huang, Jen-Yuan and Wang, Haofan and Wang, Qixun and Bai, Xu and Ai, Hao and Xing, Peng and Huang, Jen-Tse},
188
+ journal={arXiv preprint arXiv:2410.06551},
189
+ year={2024}
190
+ }
191
+ ```
assets/.DS_Store ADDED
Binary file (6.15 kB). View file
 
assets/Konan.png ADDED
assets/Naruto.png ADDED
assets/cottage.png ADDED
assets/dog.png ADDED
assets/lady.png ADDED
assets/man.png ADDED
assets/panda.png ADDED
assets/sculpture.png ADDED

Git LFS Details

  • SHA256: 2c4af7c3dc545d2f48b0ac2afef69bd7b1f0489ced7ea452d92f69ff5a9d4019
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
assets/teaser_figure.png ADDED

Git LFS Details

  • SHA256: a9c7e8e59af17516d11e21c5bc56b48824a3875c81e4afb181a5c3facc217d08
  • Pointer size: 133 Bytes
  • Size of remote file: 16.9 MB
config_files/IR_dataset.yaml ADDED
@@ -0,0 +1,9 @@
1
+ datasets:
2
+ - dataset_folder: 'ffhq'
3
+ dataset_weight: 0.1
4
+ - dataset_folder: 'DIV2K'
5
+ dataset_weight: 0.3
6
+ - dataset_folder: 'LSDIR'
7
+ dataset_weight: 0.3
8
+ - dataset_folder: 'Flickr2K'
9
+ dataset_weight: 0.1
config_files/losses.yaml ADDED
@@ -0,0 +1,19 @@
1
+ diffusion_losses:
2
+ - name: L2Loss
3
+ weight: 1
4
+ lcm_losses:
5
+ - name: HuberLoss
6
+ weight: 1
7
+ # - name: DINOLoss
8
+ # weight: 1e-3
9
+ # - name: L2Loss
10
+ # weight: 5e-2
11
+ # - name: LPIPSLoss
12
+ # weight: 1e-3
13
+ # - name: DreamSIMLoss
14
+ # weight: 1e-3
15
+ # - name: IDLoss
16
+ # weight: 1e-3
17
+ # visualize_every_k: 50
18
+ # init_params:
19
+ # pretrained_arcface_path: /home/dcor/orlichter/consistency_encoder_private/pretrained_models/model_ir_se50.pth
config_files/val_dataset.yaml ADDED
@@ -0,0 +1,7 @@
1
+ datasets:
2
+ - dataset_folder: 'ffhq'
3
+ dataset_weight: 0.1
4
+ - dataset_folder: 'DIV2K'
5
+ dataset_weight: 0.45
6
+ - dataset_folder: 'LSDIR'
7
+ dataset_weight: 0.45
data/data_config.py ADDED
@@ -0,0 +1,14 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional, List
3
+
4
+
5
+ @dataclass
6
+ class SingleDataConfig:
7
+ dataset_folder: str
8
+ imagefolder: bool = True
9
+ dataset_weight: float = 1.0 # Not used yet
10
+
11
+ @dataclass
12
+ class DataConfig:
13
+ datasets: List[SingleDataConfig]
14
+ val_dataset: Optional[SingleDataConfig] = None
data/dataset.py ADDED
@@ -0,0 +1,202 @@
1
+ from pathlib import Path
2
+ from typing import Optional
3
+
4
+ from PIL import Image
5
+ from PIL.ImageOps import exif_transpose
6
+ from torch.utils.data import Dataset
7
+ from torchvision import transforms
8
+ import json
9
+ import random
10
+ from facenet_pytorch import MTCNN
11
+ import torch
12
+
13
+ from utils.utils import extract_faces_and_landmarks, REFERNCE_FACIAL_POINTS_RELATIVE
14
+
15
+ def load_image(image_path: str) -> Image:
16
+ image = Image.open(image_path)
17
+ image = exif_transpose(image)
18
+ if not image.mode == "RGB":
19
+ image = image.convert("RGB")
20
+ return image
21
+
22
+
23
+ class ImageDataset(Dataset):
24
+ """
25
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
26
+ It pre-processes the images.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ instance_data_root,
32
+ instance_prompt,
33
+ metadata_path: Optional[str] = None,
34
+ prompt_in_filename=False,
35
+ use_only_vanilla_for_encoder=False,
36
+ concept_placeholder='a face',
37
+ size=1024,
38
+ center_crop=False,
39
+ aug_images=False,
40
+ use_only_decoder_prompts=False,
41
+ crop_head_for_encoder_image=False,
42
+ random_target_prob=0.0,
43
+ ):
44
+ self.mtcnn = MTCNN(device='cuda:0')
45
+ self.mtcnn.forward = self.mtcnn.detect
46
+ resize_factor = 1.3
47
+ self.resized_reference_points = REFERNCE_FACIAL_POINTS_RELATIVE / resize_factor + (resize_factor - 1) / (2 * resize_factor)
48
+ self.size = size
49
+ self.center_crop = center_crop
50
+ self.concept_placeholder = concept_placeholder
51
+ self.prompt_in_filename = prompt_in_filename
52
+ self.aug_images = aug_images
53
+
54
+ self.instance_prompt = instance_prompt
55
+ self.custom_instance_prompts = None
56
+ self.name_to_label = None
57
+ self.crop_head_for_encoder_image = crop_head_for_encoder_image
58
+ self.random_target_prob = random_target_prob
59
+
60
+ self.use_only_decoder_prompts = use_only_decoder_prompts
61
+
62
+ self.instance_data_root = Path(instance_data_root)
63
+
64
+ if not self.instance_data_root.exists():
65
+ raise ValueError(f"Instance images root {self.instance_data_root} doesn't exist.")
66
+
67
+ if metadata_path is not None:
68
+ with open(metadata_path, 'r') as f:
69
+ self.name_to_label = json.load(f) # dict of filename: label
70
+ # Create a reversed mapping
71
+ self.label_to_names = {}
72
+ for name, label in self.name_to_label.items():
73
+ if use_only_vanilla_for_encoder and 'vanilla' not in name:
74
+ continue
75
+ if label not in self.label_to_names:
76
+ self.label_to_names[label] = []
77
+ self.label_to_names[label].append(name)
78
+ self.all_paths = [self.instance_data_root / filename for filename in self.name_to_label.keys()]
79
+
80
+ # Verify all paths exist
81
+ n_all_paths = len(self.all_paths)
82
+ self.all_paths = [path for path in self.all_paths if path.exists()]
83
+ print(f'Found {len(self.all_paths)} out of {n_all_paths} paths.')
84
+ else:
85
+ self.all_paths = [path for path in list(Path(instance_data_root).glob('**/*')) if
86
+ path.suffix.lower() in [".png", ".jpg", ".jpeg"]]
87
+ # Sort by name so that order for validation remains the same across runs
88
+ self.all_paths = sorted(self.all_paths, key=lambda x: x.stem)
89
+
90
+ self.custom_instance_prompts = None
91
+
92
+ self._length = len(self.all_paths)
93
+
94
+ self.class_data_root = None
95
+
96
+ self.image_transforms = transforms.Compose(
97
+ [
98
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
99
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
100
+ transforms.ToTensor(),
101
+ transforms.Normalize([0.5], [0.5]),
102
+ ]
103
+ )
104
+
105
+ if self.prompt_in_filename:
106
+ self.prompts_set = set([self._path_to_prompt(path) for path in self.all_paths])
107
+ else:
108
+ self.prompts_set = set([self.instance_prompt])
109
+
110
+ if self.aug_images:
111
+ self.aug_transforms = transforms.Compose(
112
+ [
113
+ transforms.RandomResizedCrop(size, scale=(0.8, 1.0), ratio=(1.0, 1.0)),
114
+ transforms.RandomHorizontalFlip(p=0.5)
115
+ ]
116
+ )
117
+
118
+ def __len__(self):
119
+ return self._length
120
+
121
+ def _path_to_prompt(self, path):
122
+ # Remove the extension and seed
123
+ split_path = path.stem.split('_')
124
+ while split_path[-1].isnumeric():
125
+ split_path = split_path[:-1]
126
+
127
+ prompt = ' '.join(split_path)
128
+ # Replace placeholder in prompt with training placeholder
129
+ prompt = prompt.replace('conceptname', self.concept_placeholder)
130
+ return prompt
131
+
132
+ def __getitem__(self, index):
133
+ example = {}
134
+ instance_path = self.all_paths[index]
135
+ instance_image = load_image(instance_path)
136
+ example["instance_images"] = self.image_transforms(instance_image)
137
+ if self.prompt_in_filename:
138
+ example["instance_prompt"] = self._path_to_prompt(instance_path)
139
+ else:
140
+ example["instance_prompt"] = self.instance_prompt
141
+
142
+ if self.name_to_label is None:
143
+ # If no labels, simply take the same image but with different augmentation
144
+ example["encoder_images"] = self.aug_transforms(example["instance_images"]) if self.aug_images else example["instance_images"]
145
+ example["encoder_prompt"] = example["instance_prompt"]
146
+ else:
147
+ # Randomly select another image with the same label
148
+ instance_name = str(instance_path.relative_to(self.instance_data_root))
149
+ instance_label = self.name_to_label[instance_name]
150
+ label_set = set(self.label_to_names[instance_label])
151
+ if len(label_set) == 1:
152
+ # We are not supposed to have only one image per label, but just in case
153
+ encoder_image_name = instance_name
154
+ print(f'WARNING: Only one image for label {instance_label}.')
155
+ else:
156
+ encoder_image_name = random.choice(list(label_set - {instance_name}))
157
+ encoder_image = load_image(self.instance_data_root / encoder_image_name)
158
+ example["encoder_images"] = self.image_transforms(encoder_image)
159
+
160
+ if self.prompt_in_filename:
161
+ example["encoder_prompt"] = self._path_to_prompt(self.instance_data_root / encoder_image_name)
162
+ else:
163
+ example["encoder_prompt"] = self.instance_prompt
164
+
165
+ if self.crop_head_for_encoder_image:
166
+ example["encoder_images"] = extract_faces_and_landmarks(example["encoder_images"][None], self.size, self.mtcnn, self.resized_reference_points)[0][0]
167
+ example["encoder_prompt"] = example["encoder_prompt"].format(placeholder="<ph>")
168
+ example["instance_prompt"] = example["instance_prompt"].format(placeholder="<s*>")
169
+
170
+ if random.random() < self.random_target_prob:
171
+ random_path = random.choice(self.all_paths)
172
+
173
+ random_image = load_image(random_path)
174
+ example["instance_images"] = self.image_transforms(random_image)
175
+ if self.prompt_in_filename:
176
+ example["instance_prompt"] = self._path_to_prompt(random_path)
177
+
178
+
179
+ if self.use_only_decoder_prompts:
180
+ example["encoder_prompt"] = example["instance_prompt"]
181
+
182
+ return example
183
+
184
+
185
+ def collate_fn(examples, with_prior_preservation=False):
186
+ pixel_values = [example["instance_images"] for example in examples]
187
+ encoder_pixel_values = [example["encoder_images"] for example in examples]
188
+ prompts = [example["instance_prompt"] for example in examples]
189
+ encoder_prompts = [example["encoder_prompt"] for example in examples]
190
+
191
+ if with_prior_preservation:
192
+ raise NotImplementedError("Prior preservation not implemented.")
193
+
194
+ pixel_values = torch.stack(pixel_values)
195
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
196
+
197
+ encoder_pixel_values = torch.stack(encoder_pixel_values)
198
+ encoder_pixel_values = encoder_pixel_values.to(memory_format=torch.contiguous_format).float()
199
+
200
+ batch = {"pixel_values": pixel_values, "encoder_pixel_values": encoder_pixel_values,
201
+ "prompts": prompts, "encoder_prompts": encoder_prompts}
202
+ return batch
docs/.DS_Store ADDED
Binary file (6.15 kB). View file
 
docs/static/.DS_Store ADDED
Binary file (6.15 kB). View file
 
environment.yaml ADDED
@@ -0,0 +1,37 @@
1
+ name: instantir
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - numpy
9
+ - pandas
10
+ - pillow
11
+ - pip
12
+ - python=3.9.15
13
+ - pytorch=2.2.2
14
+ - pytorch-lightning=1.6.5
15
+ - pytorch-cuda=12.1
16
+ - setuptools
17
+ - torchaudio=2.2.2
18
+ - torchmetrics
19
+ - torchvision=0.17.2
20
+ - tqdm
21
+ - pip:
22
+ - accelerate==0.25.0
23
+ - diffusers==0.24.0
24
+ - einops
25
+ - open-clip-torch
26
+ - opencv-python==4.8.1.78
27
+ - tokenizers
28
+ - transformers==4.36.2
29
+ - kornia
30
+ - facenet_pytorch
31
+ - lpips
32
+ - dreamsim
33
+ - pyrallis
34
+ - wandb
35
+ - insightface
36
+ - onnxruntime==1.17.0
37
+ - -e git+https://github.com/openai/CLIP.git@main#egg=clip
gradio_demo/app.py ADDED
@@ -0,0 +1,250 @@
1
+ import os
2
+ import sys
3
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
4
+
5
+ import torch
6
+ import numpy as np
7
+ import gradio as gr
8
+ from PIL import Image
9
+
10
+ from diffusers import DDPMScheduler
11
+ from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
12
+
13
+ from module.ip_adapter.utils import load_adapter_to_pipe
14
+ from pipelines.sdxl_instantir import InstantIRPipeline
15
+
16
+ def resize_img(input_image, max_side=1280, min_side=1024, size=None,
17
+ pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
18
+
19
+ w, h = input_image.size
20
+ if size is not None:
21
+ w_resize_new, h_resize_new = size
22
+ else:
23
+ # ratio = min_side / min(h, w)
24
+ # w, h = round(ratio*w), round(ratio*h)
25
+ ratio = max_side / max(h, w)
26
+ input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
27
+ w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
28
+ h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
29
+ input_image = input_image.resize([w_resize_new, h_resize_new], mode)
30
+
31
+ if pad_to_max_side:
32
+ res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
33
+ offset_x = (max_side - w_resize_new) // 2
34
+ offset_y = (max_side - h_resize_new) // 2
35
+ res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
36
+ input_image = Image.fromarray(res)
37
+ return input_image
38
+
39
+ instantir_path = os.environ['INSTANTIR_PATH']
40
+
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+ sdxl_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
43
+ dinov2_repo_id = "facebook/dinov2-large"
44
+ lcm_repo_id = "latent-consistency/lcm-lora-sdxl"
45
+
46
+ if torch.cuda.is_available():
47
+ torch_dtype = torch.float16
48
+ else:
49
+ torch_dtype = torch.float32
50
+
51
+ # Load pretrained models.
52
+ print("Initializing pipeline...")
53
+ pipe = InstantIRPipeline.from_pretrained(
54
+ sdxl_repo_id,
55
+ torch_dtype=torch_dtype,
56
+ )
57
+
58
+ # Image prompt projector.
59
+ print("Loading LQ-Adapter...")
60
+ load_adapter_to_pipe(
61
+ pipe,
62
+ f"{instantir_path}/adapter.pt",
63
+ dinov2_repo_id,
64
+ )
65
+
66
+ # Prepare previewer
67
+ lora_alpha = pipe.prepare_previewers(instantir_path)
68
+ print(f"use lora alpha {lora_alpha}")
69
+ lora_alpha = pipe.prepare_previewers(lcm_repo_id, use_lcm=True)
70
+ print(f"use lora alpha {lora_alpha}")
71
+ pipe.to(device=device, dtype=torch_dtype)
72
+ pipe.scheduler = DDPMScheduler.from_pretrained(sdxl_repo_id, subfolder="scheduler")
73
+ lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
74
+
75
+ # Load weights.
76
+ print("Loading checkpoint...")
77
+ aggregator_state_dict = torch.load(
78
+ f"{instantir_path}/aggregator.pt",
79
+ map_location="cpu"
80
+ )
81
+ pipe.aggregator.load_state_dict(aggregator_state_dict, strict=True)
82
+ pipe.aggregator.to(device=device, dtype=torch_dtype)
83
+
84
+ MAX_SEED = np.iinfo(np.int32).max
85
+ MAX_IMAGE_SIZE = 1024
86
+
87
+ PROMPT = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \
88
+ ultra HD, extreme meticulous detailing, skin pore detailing, \
89
+ hyper sharpness, perfect without deformations, \
90
+ taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. "
91
+
92
+ NEG_PROMPT = "blurry, out of focus, unclear, depth of field, over-smooth, \
93
+ sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \
94
+ dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \
95
+ watermark, signature, jpeg artifacts, deformed, lowres"
96
+
97
+ def unpack_pipe_out(preview_row, index):
98
+ return preview_row[index][0]
99
+
100
+ def dynamic_preview_slider(sampling_steps):
101
+ print(sampling_steps)
102
+ return gr.Slider(label="Restoration Previews", value=sampling_steps-1, minimum=0, maximum=sampling_steps-1, step=1)
103
+
104
+ def dynamic_guidance_slider(sampling_steps):
105
+ return gr.Slider(label="Start Free Rendering", value=sampling_steps, minimum=0, maximum=sampling_steps, step=1)
106
+
107
+ def show_final_preview(preview_row):
108
+ return preview_row[-1][0]
109
+
110
+ # @spaces.GPU #[uncomment to use ZeroGPU]
111
+ @torch.no_grad()
112
+ def instantir_restore(
113
+ lq, prompt="", steps=30, cfg_scale=7.0, guidance_end=1.0,
114
+ creative_restoration=False, seed=3407, height=1024, width=1024, preview_start=0.0):
115
+ if creative_restoration:
116
+ if "lcm" not in pipe.unet.active_adapters():
117
+ pipe.unet.set_adapter('lcm')
118
+ else:
119
+ if "previewer" not in pipe.unet.active_adapters():
120
+ pipe.unet.set_adapter('previewer')
121
+
122
+ if isinstance(guidance_end, int):
123
+ guidance_end = guidance_end / steps
124
+ elif guidance_end > 1.0:
125
+ guidance_end = guidance_end / steps
126
+ if isinstance(preview_start, int):
127
+ preview_start = preview_start / steps
128
+ elif preview_start > 1.0:
129
+ preview_start = preview_start / steps
130
+ lq = [resize_img(lq.convert("RGB"), size=(width, height))]
131
+ generator = torch.Generator(device=device).manual_seed(seed)
132
+ timesteps = [
133
+ i * (1000//steps) + pipe.scheduler.config.steps_offset for i in range(0, steps)
134
+ ]
135
+ timesteps = timesteps[::-1]
136
+
137
+ prompt = PROMPT if len(prompt)==0 else prompt
138
+ neg_prompt = NEG_PROMPT
139
+
140
+ out = pipe(
141
+ prompt=[prompt]*len(lq),
142
+ image=lq,
143
+ num_inference_steps=steps,
144
+ generator=generator,
145
+ timesteps=timesteps,
146
+ negative_prompt=[neg_prompt]*len(lq),
147
+ guidance_scale=cfg_scale,
148
+ control_guidance_end=guidance_end,
149
+ preview_start=preview_start,
150
+ previewer_scheduler=lcm_scheduler,
151
+ return_dict=False,
152
+ save_preview_row=True,
153
+ )
154
+ for i, preview_img in enumerate(out[1]):
155
+ preview_img.append(f"preview_{i}")
156
+ return out[0][0], out[1]
157
+
158
+ examples = [
159
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
160
+ "An astronaut riding a green horse",
161
+ "A delicious ceviche cheesecake slice",
162
+ ]
163
+
164
+ css="""
165
+ #col-container {
166
+ margin: 0 auto;
167
+ max-width: 640px;
168
+ }
169
+ """
170
+
171
+ with gr.Blocks() as demo:
172
+ gr.Markdown(
173
+ """
174
+ # InstantIR: Blind Image Restoration with Instant Generative Reference.
175
+
176
+ ### **Official 🤗 Gradio demo of [InstantIR](https://arxiv.org/abs/2410.06551).**
177
+ ### **InstantIR can not only help you restore your broken image, but is also capable of imaginative re-creation following your text prompts. See advanced usage for more details!**
178
+ ## Basic usage: revitalize your image
179
+ 1. Upload an image you want to restore;
180
+ 2. Optionally, tune the `Steps` and `CFG Scale` parameters. Typically, more steps lead to better results, but fewer than 50 is recommended for efficiency;
181
+ 3. Click `InstantIR magic!`.
182
+ """)
183
+ with gr.Row():
184
+ lq_img = gr.Image(label="Low-quality image", type="pil")
185
+ with gr.Column():
186
+ with gr.Row():
187
+ steps = gr.Number(label="Steps", value=30, step=1)
188
+ cfg_scale = gr.Number(label="CFG Scale", value=7.0, step=0.1)
189
+ with gr.Row():
190
+ height = gr.Number(label="Height", value=1024, step=1)
191
+ weight = gr.Number(label="Width", value=1024, step=1)
192
+ seed = gr.Number(label="Seed", value=42, step=1)
193
+ # guidance_start = gr.Slider(label="Guidance Start", value=1.0, minimum=0.0, maximum=1.0, step=0.05)
194
+ guidance_end = gr.Slider(label="Start Free Rendering", value=30, minimum=0, maximum=30, step=1)
195
+ preview_start = gr.Slider(label="Preview Start", value=0, minimum=0, maximum=30, step=1)
196
+ prompt = gr.Textbox(label="Restoration prompts (Optional)", placeholder="")
197
+ mode = gr.Checkbox(label="Creative Restoration", value=False)
198
+ with gr.Row():
199
+ with gr.Row():
200
+ restore_btn = gr.Button("InstantIR magic!")
201
+ clear_btn = gr.ClearButton()
202
+ index = gr.Slider(label="Restoration Previews", value=29, minimum=0, maximum=29, step=1)
203
+ with gr.Row():
204
+ output = gr.Image(label="InstantIR restored", type="pil")
205
+ preview = gr.Image(label="Preview", type="pil")
206
+ pipe_out = gr.Gallery(visible=False)
207
+ clear_btn.add([lq_img, output, preview])
208
+ restore_btn.click(
209
+ instantir_restore, inputs=[
210
+ lq_img, prompt, steps, cfg_scale, guidance_end,
211
+ mode, seed, height, weight, preview_start,
212
+ ],
213
+ outputs=[output, pipe_out], api_name="InstantIR"
214
+ )
215
+ steps.change(dynamic_guidance_slider, inputs=steps, outputs=guidance_end)
216
+ output.change(dynamic_preview_slider, inputs=steps, outputs=index)
217
+ index.release(unpack_pipe_out, inputs=[pipe_out, index], outputs=preview)
218
+ output.change(show_final_preview, inputs=pipe_out, outputs=preview)
219
+ gr.Markdown(
220
+ """
221
+ ## Advanced usage:
222
+ ### Browse restoration variants:
223
+ 1. After InstantIR processing, drag the `Restoration Previews` slider to explore other in-progress versions;
224
+ 2. If you like one of them, set the `Start Free Rendering` slider to the same value to get a more refined result.
225
+ ### Creative restoration:
226
+ 1. Check the `Creative Restoration` checkbox;
227
+ 2. Input your text prompts in the `Restoration prompts` textbox;
228
+ 3. Set `Start Free Rendering` slider to a medium value (around half of the `steps`) to provide adequate room for InstantIR creation.
229
+
230
+ ## Examples
231
+ Here are some example usages of InstantIR:
232
+ """)
233
+ # examples = gr.Gallery(label="Examples")
234
+
235
+ gr.Markdown(
236
+ """
237
+ ## Citation
238
+ If InstantIR is helpful to your work, please cite our paper via:
239
+
240
+ ```
241
+ @article{huang2024instantir,
242
+ title={InstantIR: Blind Image Restoration with Instant Generative Reference},
243
+ author={Huang, Jen-Yuan and Wang, Haofan and Wang, Qixun and Bai, Xu and Ai, Hao and Xing, Peng and Huang, Jen-Tse},
244
+ journal={arXiv preprint arXiv:2410.06551},
245
+ year={2024}
246
+ }
247
+ ```
248
+ """)
249
+
250
+ demo.queue().launch()
infer.py ADDED
@@ -0,0 +1,381 @@
1
+ import os
2
+ import argparse
3
+ import numpy as np
4
+ import torch
5
+
6
+ from PIL import Image
7
+ from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
8
+
9
+ from diffusers import DDPMScheduler
10
+
11
+ from module.ip_adapter.utils import load_adapter_to_pipe
12
+ from pipelines.sdxl_instantir import InstantIRPipeline
13
+
14
+
15
+ def name_unet_submodules(unet):
16
+ def recursive_find_module(name, module, end=False):
17
+ if end:
18
+ for sub_name, sub_module in module.named_children():
19
+ sub_module.full_name = f"{name}.{sub_name}"
20
+ return
21
+ if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return
22
+ elif "resnets" in name: return
23
+ for sub_name, sub_module in module.named_children():
24
+ end = True if sub_name == "transformer_blocks" else False
25
+ recursive_find_module(f"{name}.{sub_name}", sub_module, end)
26
+
27
+ for name, module in unet.named_children():
28
+ recursive_find_module(name, module)
29
+
30
+
31
+ def resize_img(input_image, max_side=1280, min_side=1024, size=None,
32
+ pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
33
+
34
+ w, h = input_image.size
35
+ if size is not None:
36
+ w_resize_new, h_resize_new = size
37
+ else:
38
+ # ratio = min_side / min(h, w)
39
+ # w, h = round(ratio*w), round(ratio*h)
40
+ ratio = max_side / max(h, w)
41
+ input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
42
+ w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
43
+ h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
44
+ input_image = input_image.resize([w_resize_new, h_resize_new], mode)
45
+
46
+ if pad_to_max_side:
47
+ res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
48
+ offset_x = (max_side - w_resize_new) // 2
49
+ offset_y = (max_side - h_resize_new) // 2
50
+ res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
51
+ input_image = Image.fromarray(res)
52
+ return input_image
53
+
54
+
55
+ def tensor_to_pil(images):
56
+ """
57
+ Convert image tensor or a batch of image tensors to PIL image(s).
58
+ """
59
+ images = images.clamp(0, 1)
60
+ images_np = images.detach().cpu().numpy()
61
+ if images_np.ndim == 4:
62
+ images_np = np.transpose(images_np, (0, 2, 3, 1))
63
+ elif images_np.ndim == 3:
64
+ images_np = np.transpose(images_np, (1, 2, 0))
65
+ images_np = images_np[None, ...]
66
+ images_np = (images_np * 255).round().astype("uint8")
67
+ if images_np.shape[-1] == 1:
68
+ # special case for grayscale (single channel) images
69
+ pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images_np]
70
+ else:
71
+ pil_images = [Image.fromarray(image[:, :, :3]) for image in images_np]
72
+
73
+ return pil_images
74
+
75
+
76
+ def calc_mean_std(feat, eps=1e-5):
77
+ """Calculate mean and std for adaptive_instance_normalization.
78
+ Args:
79
+ feat (Tensor): 4D tensor.
80
+ eps (float): A small value added to the variance to avoid
81
+ divide-by-zero. Default: 1e-5.
82
+ """
83
+ size = feat.size()
84
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
85
+ b, c = size[:2]
86
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
87
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
88
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
89
+ return feat_mean, feat_std
90
+
91
+
92
+ def adaptive_instance_normalization(content_feat, style_feat):
93
+ size = content_feat.size()
94
+ style_mean, style_std = calc_mean_std(style_feat)
95
+ content_mean, content_std = calc_mean_std(content_feat)
96
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
97
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
98
+
99
+
100
+ def main(args, device):
101
+
102
+ # Load pretrained models.
103
+ pipe = InstantIRPipeline.from_pretrained(
104
+ args.sdxl_path,
105
+ torch_dtype=torch.float16,
106
+ )
107
+
108
+ # Image prompt projector.
109
+ print("Loading LQ-Adapter...")
110
+ load_adapter_to_pipe(
111
+ pipe,
112
+ args.adapter_model_path if args.adapter_model_path is not None else os.path.join(args.instantir_path, 'adapter.pt'),
113
+ args.vision_encoder_path,
114
+ use_clip_encoder=args.use_clip_encoder,
115
+ )
116
+
117
+ # Prepare previewer
118
+ previewer_lora_path = args.previewer_lora_path if args.previewer_lora_path is not None else args.instantir_path
119
+ if previewer_lora_path is not None:
120
+ lora_alpha = pipe.prepare_previewers(previewer_lora_path)
121
+ print(f"use lora alpha {lora_alpha}")
122
+ pipe.to(device=device, dtype=torch.float16)
123
+ pipe.scheduler = DDPMScheduler.from_pretrained(args.sdxl_path, subfolder="scheduler")
124
+ lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
125
+
126
+ # Load weights.
127
+ print("Loading checkpoint...")
128
+ pretrained_state_dict = torch.load(os.path.join(args.instantir_path, "aggregator.pt"), map_location="cpu")
129
+ pipe.aggregator.load_state_dict(pretrained_state_dict)
130
+ pipe.aggregator.to(device, dtype=torch.float16)
131
+
132
+ #################### Restoration ####################
133
+
134
+ post_fix = f"_{args.post_fix}" if args.post_fix else ""
135
+ os.makedirs(f"{args.out_path}/{post_fix}", exist_ok=True)
136
+
137
+ processed_imgs = os.listdir(os.path.join(args.out_path, post_fix))
138
+ lq_files = []
139
+ lq_batch = []
140
+ if os.path.isfile(args.test_path):
141
+ all_inputs = [args.test_path.split("/")[-1]]
142
+ else:
143
+ all_inputs = os.listdir(args.test_path)
144
+ all_inputs.sort()
145
+ for file in all_inputs:
146
+ if file in processed_imgs:
147
+ print(f"Skip {file}")
148
+ continue
149
+ lq_batch.append(f"{file}")
150
+ if len(lq_batch) == args.batch_size:
151
+ lq_files.append(lq_batch)
152
+ lq_batch = []
153
+
154
+ if len(lq_batch) > 0:
155
+ lq_files.append(lq_batch)
156
+
157
+ for lq_batch in lq_files:
158
+ generator = torch.Generator(device=device).manual_seed(args.seed)
159
+ pil_lqs = [Image.open(os.path.join(args.test_path, file)) for file in lq_batch]
160
+ if args.width is None or args.height is None:
161
+ lq = [resize_img(pil_lq.convert("RGB"), size=None) for pil_lq in pil_lqs]
162
+ else:
163
+ lq = [resize_img(pil_lq.convert("RGB"), size=(args.width, args.height)) for pil_lq in pil_lqs]
164
+ timesteps = None
165
+ if args.denoising_start < 1000:
166
+ timesteps = [
167
+ i * (args.denoising_start//args.num_inference_steps) + pipe.scheduler.config.steps_offset for i in range(0, args.num_inference_steps)
168
+ ]
169
+ timesteps = timesteps[::-1]
170
+ pipe.scheduler.set_timesteps(args.num_inference_steps, device)
171
+ timesteps = pipe.scheduler.timesteps
172
+ if args.prompt is None or len(args.prompt) == 0:
173
+ prompt = "Photorealistic, highly detailed, hyper detailed photo - realistic maximum detail, 32k, \
174
+ ultra HD, extreme meticulous detailing, skin pore detailing, \
175
+ hyper sharpness, perfect without deformations, \
176
+ taken using a Canon EOS R camera, Cinematic, High Contrast, Color Grading. "
177
+ else:
178
+ prompt = args.prompt
179
+ if not isinstance(prompt, list):
180
+ prompt = [prompt]
181
+ prompt = prompt*len(lq)
182
+ if args.neg_prompt is None or len(args.neg_prompt) == 0:
183
+ neg_prompt = "blurry, out of focus, unclear, depth of field, over-smooth, \
184
+ sketch, oil painting, cartoon, CG Style, 3D render, unreal engine, \
185
+ dirty, messy, worst quality, low quality, frames, painting, illustration, drawing, art, \
186
+ watermark, signature, jpeg artifacts, deformed, lowres"
187
+ else:
188
+ neg_prompt = args.neg_prompt
189
+ if not isinstance(neg_prompt, list):
190
+ neg_prompt = [neg_prompt]
191
+ neg_prompt = neg_prompt*len(lq)
192
+ image = pipe(
193
+ prompt=prompt,
194
+ image=lq,
195
+ num_inference_steps=args.num_inference_steps,
196
+ generator=generator,
197
+ timesteps=timesteps,
198
+ negative_prompt=neg_prompt,
199
+ guidance_scale=args.cfg,
200
+ previewer_scheduler=lcm_scheduler,
201
+ preview_start=args.preview_start,
202
+ control_guidance_end=args.creative_start,
203
+ ).images
204
+
205
+ if args.save_preview_row:
206
+ for i, lcm_image in enumerate(image[1]):
207
+ lcm_image.save(f"./lcm/{i}.png")
208
+ for i, rec_image in enumerate(image):
209
+ rec_image.save(f"{args.out_path}/{post_fix}/{lq_batch[i]}")
210
+
211
+
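As a quick illustration of the custom schedule built above when --denoising_start is below 1000, the following sketch reproduces the arithmetic with hypothetical values (the real steps_offset comes from the scheduler config; 1 is only assumed here):

    denoising_start = 900        # hypothetical --denoising_start
    num_inference_steps = 30     # hypothetical --num_inference_steps
    steps_offset = 1             # assumed value of pipe.scheduler.config.steps_offset

    spacing = denoising_start // num_inference_steps                              # 30
    timesteps = [i * spacing + steps_offset for i in range(num_inference_steps)]
    timesteps = timesteps[::-1]                                                   # descending: [871, 841, ..., 31, 1]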
212
+ if __name__ == "__main__":
213
+ parser = argparse.ArgumentParser(description="InstantIR pipeline")
214
+ parser.add_argument(
215
+ "--sdxl_path",
216
+ type=str,
217
+ default=None,
218
+ required=True,
219
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
220
+ )
221
+ parser.add_argument(
222
+ "--previewer_lora_path",
223
+ type=str,
224
+ default=None,
225
+ help="Path to LCM lora or model identifier from huggingface.co/models.",
226
+ )
227
+ parser.add_argument(
228
+ "--pretrained_vae_model_name_or_path",
229
+ type=str,
230
+ default=None,
231
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
232
+ )
233
+ parser.add_argument(
234
+ "--instantir_path",
235
+ type=str,
236
+ default=None,
237
+ required=True,
238
+ help="Path to pretrained instantir model.",
239
+ )
240
+ parser.add_argument(
241
+ "--vision_encoder_path",
242
+ type=str,
243
+ default='/share/huangrenyuan/model_zoo/vis_backbone/dinov2_large',
244
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
245
+ )
246
+ parser.add_argument(
247
+ "--adapter_model_path",
248
+ type=str,
249
+ default=None,
250
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
251
+ )
252
+ parser.add_argument(
253
+ "--adapter_tokens",
254
+ type=int,
255
+ default=64,
256
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
257
+ )
258
+ parser.add_argument(
259
+ "--use_clip_encoder",
260
+ action="store_true",
261
+ help="Whether or not to use DINO as image encoder, else CLIP encoder.",
262
+ )
263
+ parser.add_argument(
264
+ "--denoising_start",
265
+ type=int,
266
+ default=1000,
267
+ help="Diffusion start timestep."
268
+ )
269
+ parser.add_argument(
270
+ "--num_inference_steps",
271
+ type=int,
272
+ default=30,
273
+ help="Diffusion steps."
274
+ )
275
+ parser.add_argument(
276
+ "--creative_start",
277
+ type=float,
278
+ default=1.0,
279
+ help="Proportion of timesteps for creative restoration. 1.0 means no creative restoration while 0.0 means completely free rendering."
280
+ )
281
+ parser.add_argument(
282
+ "--preview_start",
283
+ type=float,
284
+ default=0.0,
285
+ help="Proportion of timesteps to stop previewing at the begining to enhance fidelity to input."
286
+ )
287
+ parser.add_argument(
288
+ "--resolution",
289
+ type=int,
290
+ default=1024,
291
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
292
+ )
293
+ parser.add_argument(
294
+ "--batch_size",
295
+ type=int,
296
+ default=6,
297
+ help="Test batch size."
298
+ )
299
+ parser.add_argument(
300
+ "--width",
301
+ type=int,
302
+ default=None,
303
+ help="Output image width."
304
+ )
305
+ parser.add_argument(
306
+ "--height",
307
+ type=int,
308
+ default=None,
309
+ help="Output image height."
310
+ )
311
+ parser.add_argument(
312
+ "--cfg",
313
+ type=float,
314
+ default=7.0,
315
+ help="Scale of Classifier-Free-Guidance (CFG).",
316
+ )
317
+ parser.add_argument(
318
+ "--post_fix",
319
+ type=str,
320
+ default=None,
321
+ help="Subfolder name for restoration output under the output directory.",
322
+ )
323
+ parser.add_argument(
324
+ "--variant",
325
+ type=str,
326
+ default='fp16',
327
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
328
+ )
329
+ parser.add_argument(
330
+ "--revision",
331
+ type=str,
332
+ default=None,
333
+ required=False,
334
+ help="Revision of pretrained model identifier from huggingface.co/models.",
335
+ )
336
+ parser.add_argument(
337
+ "--save_preview_row",
338
+ action="store_true",
339
+ help="Whether or not to save the intermediate lcm outputs.",
340
+ )
341
+ parser.add_argument(
342
+ "--prompt",
343
+ type=str,
344
+ default='',
345
+ nargs="+",
346
+ help=(
347
+ "A set of prompts for creative restoration. Provide either a matching number of test images,"
348
+ " or a single prompt to be used with all inputs."
349
+ ),
350
+ )
351
+ parser.add_argument(
352
+ "--neg_prompt",
353
+ type=str,
354
+ default='',
355
+ nargs="+",
356
+ help=(
357
+ "A set of negative prompts for creative restoration. Provide either a matching number of test images,"
358
+ " or a single negative prompt to be used with all inputs."
359
+ ),
360
+ )
361
+ parser.add_argument(
362
+ "--test_path",
363
+ type=str,
364
+ default=None,
365
+ required=True,
366
+ help="Test directory.",
367
+ )
368
+ parser.add_argument(
369
+ "--out_path",
370
+ type=str,
371
+ default="./output",
372
+ help="Output directory.",
373
+ )
374
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
375
+ args = parser.parse_args()
376
+ args.height = args.height or args.width
377
+ args.width = args.width or args.height
378
+ if args.height is not None and (args.width % 64 != 0 or args.height % 64 != 0):
379
+ raise ValueError("Image resolution must be divisible by 64.")
380
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
381
+ main(args, device)
infer.sh ADDED
@@ -0,0 +1,6 @@
1
+ python infer.py \
2
+ --sdxl_path path/to/sdxl \
3
+ --vision_encoder_path path/to/dinov2_large \
4
+ --instantir_path path/to/instantir \
5
+ --test_path path/to/input \
6
+ --out_path path/to/output
losses/loss_config.py ADDED
@@ -0,0 +1,15 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List
3
+
4
+ @dataclass
5
+ class SingleLossConfig:
6
+ name: str
7
+ weight: float = 1.
8
+ init_params: dict = field(default_factory=dict)
9
+ visualize_every_k: int = -1
10
+
11
+
12
+ @dataclass
13
+ class LossesConfig:
14
+ diffusion_losses: List[SingleLossConfig]
15
+ lcm_losses: List[SingleLossConfig]
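For reference, a minimal sketch of how these dataclasses might be populated in code; the loss names refer to classes defined in losses/losses.py, and the arcface path is a placeholder:

    # assuming the dataclasses above are importable as a `losses` package
    from losses.loss_config import LossesConfig, SingleLossConfig

    losses_cfg = LossesConfig(
        diffusion_losses=[SingleLossConfig(name="L2Loss", weight=1.0)],
        lcm_losses=[
            SingleLossConfig(name="LPIPSLoss", weight=1.0),
            SingleLossConfig(name="IDLoss", weight=0.1,
                             init_params={"pretrained_arcface_path": "path/to/arcface.pth"}),
        ],
    )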
losses/losses.py ADDED
@@ -0,0 +1,465 @@
1
+ import torch
2
+ import wandb
3
+ import cv2
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from facenet_pytorch import MTCNN
7
+ from torchvision import transforms
8
+ from dreamsim import dreamsim
9
+ from einops import rearrange
10
+ import kornia.augmentation as K
11
+ import lpips
12
+
13
+ from pretrained_models.arcface import Backbone
14
+ from utils.vis_utils import add_text_to_image
15
+ from utils.utils import extract_faces_and_landmarks
16
+ import clip
17
+
18
+
19
+ class Loss():
20
+ """
21
+ General purpose loss class.
22
+ Mainly handles dtype and visualize_every_k.
23
+ keeps current iteration of loss, mainly for visualization purposes.
24
+ """
25
+ def __init__(self, visualize_every_k=-1, dtype=torch.float32, accelerator=None, **kwargs):
26
+ self.visualize_every_k = visualize_every_k
27
+ self.iteration = -1
28
+ self.dtype=dtype
29
+ self.accelerator = accelerator
30
+
31
+ def __call__(self, **kwargs):
32
+ self.iteration += 1
33
+ return self.forward(**kwargs)
34
+
35
+
36
+ class L1Loss(Loss):
37
+ """
38
+ Simple L1 loss between predicted_pixel_values and pixel_values
39
+
40
+ Args:
41
+ predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
42
+ encoder_pixel_values (torch.Tensor): The input image to the encoder
43
+ """
44
+ def forward(
45
+ self,
46
+ predict: torch.Tensor,
47
+ target: torch.Tensor,
48
+ **kwargs
49
+ ) -> torch.Tensor:
50
+ return F.l1_loss(predict, target, reduction="mean")
51
+
52
+
53
+ class DreamSIMLoss(Loss):
54
+ """DreamSIM loss between predicted_pixel_values and pixel_values.
55
+ DreamSIM is similar to LPIPS (https://dreamsim-nights.github.io/) but is trained on a dataset that better reflects human similarity judgments
56
+ DreamSIM expects an RGB image of size 224x224 and values between 0 and 1. So we need to normalize the input images to 0-1 range and resize them to 224x224.
57
+ Args:
58
+ predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
59
+ encoder_pixel_values (torch.Tensor): The input image to the encoder
60
+ """
61
+ def __init__(self, device: str='cuda:0', **kwargs):
62
+ super().__init__(**kwargs)
63
+ self.model, _ = dreamsim(pretrained=True, device=device)
64
+ self.model.to(dtype=self.dtype, device=device)
65
+ self.model = self.accelerator.prepare(self.model)
66
+ self.transforms = transforms.Compose([
67
+ transforms.Lambda(lambda x: (x + 1) / 2),
68
+ transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC)])
69
+
70
+ def forward(
71
+ self,
72
+ predicted_pixel_values: torch.Tensor,
73
+ encoder_pixel_values: torch.Tensor,
74
+ **kwargs,
75
+ ) -> torch.Tensor:
76
+ predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)
77
+ encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
78
+ return self.model(self.transforms(predicted_pixel_values), self.transforms(encoder_pixel_values)).mean()
79
+
80
+
81
+ class LPIPSLoss(Loss):
82
+ """LPIPS loss between predicted_pixel_values and pixel_values.
83
+ Args:
84
+ predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
85
+ encoder_pixel_values (torch.Tensor): The input image to the encoder
86
+ """
87
+ def __init__(self, **kwargs):
88
+ super().__init__(**kwargs)
89
+ self.model = lpips.LPIPS(net='vgg')
90
+ self.model.to(dtype=self.dtype, device=self.accelerator.device)
91
+ self.model = self.accelerator.prepare(self.model)
92
+
93
+ def forward(self, predict, target, **kwargs):
94
+ predict = predict.to(dtype=self.dtype)
95
+ target = target.to(dtype=self.dtype)
96
+ return self.model(predict, target).mean()
97
+
98
+
99
+ class LCMVisualization(Loss):
100
+ """Dummy loss used to visualize the LCM outputs
101
+ Args:
102
+ predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
103
+ pixel_values (torch.Tensor): The input image to the decoder
104
+ encoder_pixel_values (torch.Tensor): The input image to the encoder
105
+ """
106
+ def forward(
107
+ self,
108
+ predicted_pixel_values: torch.Tensor,
109
+ pixel_values: torch.Tensor,
110
+ encoder_pixel_values: torch.Tensor,
111
+ timesteps: torch.Tensor,
112
+ **kwargs,
113
+ ) -> None:
114
+ if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
115
+ predicted_pixel_values = rearrange(predicted_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
116
+ pixel_values = rearrange(pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
117
+ encoder_pixel_values = rearrange(encoder_pixel_values, "n c h w -> (n h) w c").detach().cpu().numpy()
118
+ image = np.hstack([encoder_pixel_values, pixel_values, predicted_pixel_values])
119
+ for tracker in self.accelerator.trackers:
120
+ if tracker.name == 'wandb':
121
+ tracker.log({"TrainVisualization": wandb.Image(image, caption=f"Encoder Input Image, Decoder Input Image, Predicted LCM Image. Timesteps {timesteps.cpu().tolist()}")})
122
+ return torch.tensor(0.0)
123
+
124
+
125
+ class L2Loss(Loss):
126
+ """
127
+ Regular diffusion loss between predicted noise and target noise.
128
+
129
+ Args:
130
+ predicted_noise (torch.Tensor): noise predicted by the diffusion model
131
+ target_noise (torch.Tensor): actual noise added to the image.
132
+ """
133
+ def forward(
134
+ self,
135
+ predict: torch.Tensor,
136
+ target: torch.Tensor,
137
+ weights: torch.Tensor = None,
138
+ **kwargs
139
+ ) -> torch.Tensor:
140
+ if weights is not None:
141
+ loss = (predict.float() - target.float()).pow(2) * weights
142
+ return loss.mean()
143
+ return F.mse_loss(predict.float(), target.float(), reduction="mean")
144
+
145
+
146
+ class HuberLoss(Loss):
147
+ """Huber loss between predicted_pixel_values and pixel_values.
148
+ Args:
149
+ predicted_pixel_values (torch.Tensor): The predicted pixel values using 1 step LCM and the VAE decoder.
150
+ encoder_pixel_values (torch.Tensor): The input image to the encoder
151
+ """
152
+ def __init__(self, huber_c=0.001, **kwargs):
153
+ super().__init__(**kwargs)
154
+ self.huber_c = huber_c
155
+
156
+ def forward(
157
+ self,
158
+ predict: torch.Tensor,
159
+ target: torch.Tensor,
160
+ weights: torch.Tensor = None,
161
+ **kwargs
162
+ ) -> torch.Tensor:
163
+ loss = torch.sqrt((predict.float() - target.float()) ** 2 + self.huber_c**2) - self.huber_c
164
+ if weights is not None:
165
+ return (loss * weights).mean()
166
+ return loss.mean()
167
+
168
+
169
+ class WeightedNoiseLoss(Loss):
170
+ """
171
+ Weighted diffusion loss between predicted noise and target noise.
172
+
173
+ Args:
174
+ predicted_noise (torch.Tensor): noise predicted by the diffusion model
175
+ target_noise (torch.Tensor): actual noise added to the image.
176
+ loss_batch_weights (torch.Tensor): weighting for each batch item. Can be used to e.g. zero-out loss for InstantID training if keypoint extraction fails.
177
+ """
178
+ def forward(
179
+ self,
180
+ predict: torch.Tensor,
181
+ target: torch.Tensor,
182
+ weights,
183
+ **kwargs
184
+ ) -> torch.Tensor:
185
+ return F.mse_loss(predict.float() * weights, target.float() * weights, reduction="mean")
186
+
187
+
188
+ class IDLoss(Loss):
189
+ """
190
+ Use pretrained facenet model to extract features from the face of the predicted image and target image.
191
+ Facenet expects 112x112 images, so we crop the face using MTCNN and resize it to 112x112.
192
+ Then we use the cosine similarity between the features to calculate the loss (the loss is 1 - cosine similarity).
193
+ Also notice that the outputs of facenet are normalized, so their dot product equals the cosine similarity.
194
+ """
195
+ def __init__(self, pretrained_arcface_path: str, skip_not_found=True, **kwargs):
196
+ super().__init__(**kwargs)
197
+ assert pretrained_arcface_path is not None, "please pass `pretrained_arcface_path` in the losses config. You can download the pretrained model from "\
198
+ "https://drive.google.com/file/d/1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn/view?usp=sharing"
199
+ self.mtcnn = MTCNN(device=self.accelerator.device)
200
+ self.mtcnn.forward = self.mtcnn.detect
201
+ self.facenet_input_size = 112 # Has to be 112, can't find weights for 224 size.
202
+ self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
203
+ self.facenet.load_state_dict(torch.load(pretrained_arcface_path))
204
+ self.face_pool = torch.nn.AdaptiveAvgPool2d((self.facenet_input_size, self.facenet_input_size))
205
+ self.facenet.requires_grad_(False)
206
+ self.facenet.eval()
207
+ self.facenet.to(device=self.accelerator.device, dtype=self.dtype) # not implemented for half precision
208
+ self.face_pool.to(device=self.accelerator.device, dtype=self.dtype) # not implemented for half precision
209
+ self.visualization_resize = transforms.Resize((self.facenet_input_size, self.facenet_input_size), interpolation=transforms.InterpolationMode.BICUBIC)
210
+ self.reference_facial_points = np.array([[38.29459953, 51.69630051],
211
+ [72.53179932, 51.50139999],
212
+ [56.02519989, 71.73660278],
213
+ [41.54930115, 92.3655014],
214
+ [70.72990036, 92.20410156]
215
+ ]) # Original reference points are for 112 x 96; 8 is added to the x axis to make it 112 x 112
216
+ self.facenet, self.face_pool, self.mtcnn = self.accelerator.prepare(self.facenet, self.face_pool, self.mtcnn)
217
+
218
+ self.skip_not_found = skip_not_found
219
+
220
+ def extract_feats(self, x: torch.Tensor):
221
+ """
222
+ Extract features from the face of the image using facenet model.
223
+ """
224
+ x = self.face_pool(x)
225
+ x_feats = self.facenet(x)
226
+
227
+ return x_feats
228
+
229
+ def forward(
230
+ self,
231
+ predicted_pixel_values: torch.Tensor,
232
+ encoder_pixel_values: torch.Tensor,
233
+ timesteps: torch.Tensor,
234
+ **kwargs
235
+ ):
236
+ encoder_pixel_values = encoder_pixel_values.to(dtype=self.dtype)
237
+ predicted_pixel_values = predicted_pixel_values.to(dtype=self.dtype)
238
+
239
+ predicted_pixel_values_face, predicted_invalid_indices = extract_faces_and_landmarks(predicted_pixel_values, mtcnn=self.mtcnn)
240
+ with torch.no_grad():
241
+ encoder_pixel_values_face, source_invalid_indices = extract_faces_and_landmarks(encoder_pixel_values, mtcnn=self.mtcnn)
242
+
243
+ if self.skip_not_found:
244
+ valid_indices = []
245
+ for i in range(predicted_pixel_values.shape[0]):
246
+ if i not in predicted_invalid_indices and i not in source_invalid_indices:
247
+ valid_indices.append(i)
248
+ else:
249
+ valid_indices = list(range(predicted_pixel_values.shape[0]))
250
+
251
+ valid_indices = torch.tensor(valid_indices).to(device=predicted_pixel_values.device)
252
+
253
+ if len(valid_indices) == 0:
254
+ loss = (predicted_pixel_values_face * 0.0).mean() # It's done this way so the `backwards` will delete the computation graph of the predicted_pixel_values.
255
+ if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
256
+ self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
257
+ return loss
258
+
259
+ with torch.no_grad():
260
+ pixel_values_feats = self.extract_feats(encoder_pixel_values_face[valid_indices])
261
+
262
+ predicted_pixel_values_feats = self.extract_feats(predicted_pixel_values_face[valid_indices])
263
+ loss = 1 - torch.einsum("bi,bi->b", pixel_values_feats, predicted_pixel_values_feats)
264
+
265
+ if self.visualize_every_k > 0 and self.iteration % self.visualize_every_k == 0:
266
+ self.visualize(predicted_pixel_values, encoder_pixel_values, predicted_pixel_values_face, encoder_pixel_values_face, timesteps, valid_indices, loss)
267
+ return loss.mean()
268
+
269
+ def visualize(
270
+ self,
271
+ predicted_pixel_values: torch.Tensor,
272
+ encoder_pixel_values: torch.Tensor,
273
+ predicted_pixel_values_face: torch.Tensor,
274
+ encoder_pixel_values_face: torch.Tensor,
275
+ timesteps: torch.Tensor,
276
+ valid_indices: torch.Tensor,
277
+ loss: torch.Tensor,
278
+ ) -> None:
279
+ small_predicted_pixel_values = (rearrange(self.visualization_resize(predicted_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy())
280
+ small_pixle_values = rearrange(self.visualization_resize(encoder_pixel_values), "n c h w -> (n h) w c").detach().cpu().numpy()
281
+ small_predicted_pixel_values_face = rearrange(self.visualization_resize(predicted_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy()
282
+ small_pixle_values_face = rearrange(self.visualization_resize(encoder_pixel_values_face), "n c h w -> (n h) w c").detach().cpu().numpy()
283
+
284
+ small_predicted_pixel_values = add_text_to_image(((small_predicted_pixel_values * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Images", add_below=False)
285
+ small_pixle_values = add_text_to_image(((small_pixle_values * 0.5 + 0.5) * 255).astype(np.uint8), "Target Images", add_below=False)
286
+ small_predicted_pixel_values_face = add_text_to_image(((small_predicted_pixel_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Pred Faces", add_below=False)
287
+ small_pixle_values_face = add_text_to_image(((small_pixle_values_face * 0.5 + 0.5) * 255).astype(np.uint8), "Target Faces", add_below=False)
288
+
289
+
290
+ final_image = np.hstack([small_predicted_pixel_values, small_pixle_values, small_predicted_pixel_values_face, small_pixle_values_face])
291
+ for tracker in self.accelerator.trackers:
292
+ if tracker.name == 'wandb':
293
+ tracker.log({"IDLoss Visualization": wandb.Image(final_image, caption=f"loss: {loss.cpu().tolist()} timesteps: {timesteps.cpu().tolist()}, valid_indices: {valid_indices.cpu().tolist()}")})
294
+
295
+
296
+ class ImageAugmentations(torch.nn.Module):
297
+ # Standard image augmentations used for CLIP loss to discourage adversarial outputs.
298
+ def __init__(self, output_size, augmentations_number, p=0.7):
299
+ super().__init__()
300
+ self.output_size = output_size
301
+ self.augmentations_number = augmentations_number
302
+
303
+ self.augmentations = torch.nn.Sequential(
304
+ K.RandomAffine(degrees=15, translate=0.1, p=p, padding_mode="border"), # type: ignore
305
+ K.RandomPerspective(0.7, p=p),
306
+ )
307
+
308
+ self.avg_pool = torch.nn.AdaptiveAvgPool2d((self.output_size, self.output_size))
309
+
310
+ self.device = None
311
+
312
+ def forward(self, input):
313
+ """Extents the input batch with augmentations
314
+ If the input consists of images [I1, I2], the extended augmented output
315
+ will be [I1_resized, I2_resized, I1_aug1, I2_aug1, I1_aug2, I2_aug2 ...]
316
+ Args:
317
+ input ([type]): input batch of shape [batch, C, H, W]
318
+ Returns:
319
+ updated batch: of shape [batch * augmentations_number, C, H, W]
320
+ """
321
+ # We want to multiply the number of images in the batch (in contrast to regular augmentations
322
+ # that do not change the number of samples in the batch).
323
+ resized_images = self.avg_pool(input)
324
+ resized_images = torch.tile(resized_images, dims=(self.augmentations_number, 1, 1, 1))
325
+
326
+ batch_size = input.shape[0]
327
+ # We want at least one non augmented image
328
+ non_augmented_batch = resized_images[:batch_size]
329
+ augmented_batch = self.augmentations(resized_images[batch_size:])
330
+ updated_batch = torch.cat([non_augmented_batch, augmented_batch], dim=0)
331
+
332
+ return updated_batch
333
+
334
+
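A small usage sketch of the class above (shapes are illustrative): with a batch of 2 images and augmentations_number=4, the output keeps the 2 resized originals and appends 6 augmented copies.

    import torch

    aug = ImageAugmentations(output_size=224, augmentations_number=4)
    batch = torch.randn(2, 3, 512, 512)   # dummy batch in [batch, C, H, W]
    out = aug(batch)                      # resized originals first, then augmented tiles
    print(out.shape)                      # torch.Size([8, 3, 224, 224])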
335
+ class CLIPLoss(Loss):
336
+ def __init__(self, augmentations_number: int = 4, **kwargs):
337
+ super().__init__(**kwargs)
338
+
339
+ self.clip_model, clip_preprocess = clip.load("ViT-B/16", device=self.accelerator.device, jit=False)
340
+
341
+ self.clip_model.device = None
342
+
343
+ self.clip_model.eval().requires_grad_(False)
344
+
345
+ self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (SD output) to [0, 1].
346
+ clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions
347
+ clip_preprocess.transforms[4:]) # + skip convert PIL to tensor
348
+
349
+ self.clip_size = self.clip_model.visual.input_resolution
350
+
351
+ self.clip_normalize = transforms.Normalize(
352
+ mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
353
+ )
354
+
355
+ self.image_augmentations = ImageAugmentations(output_size=self.clip_size,
356
+ augmentations_number=augmentations_number)
357
+
358
+ self.clip_model, self.image_augmentations = self.accelerator.prepare(self.clip_model, self.image_augmentations)
359
+
360
+ def forward(self, decoder_prompts, predicted_pixel_values: torch.Tensor, **kwargs) -> torch.Tensor:
361
+
362
+ if not isinstance(decoder_prompts, list):
363
+ decoder_prompts = [decoder_prompts]
364
+
365
+ tokens = clip.tokenize(decoder_prompts).to(predicted_pixel_values.device)
366
+ image = self.preprocess(predicted_pixel_values)
367
+
368
+ logits_per_image, _ = self.clip_model(image, tokens)
369
+
370
+ logits_per_image = torch.diagonal(logits_per_image)
371
+
372
+ return (1. - logits_per_image / 100).mean()
373
+
374
+
375
+ class DINOLoss(Loss):
376
+ def __init__(
377
+ self,
378
+ dino_model,
379
+ dino_preprocess,
380
+ output_hidden_states: bool = False,
381
+ center_momentum: float = 0.9,
382
+ student_temp: float = 0.1,
383
+ teacher_temp: float = 0.04,
384
+ warmup_teacher_temp: float = 0.04,
385
+ warmup_teacher_temp_epochs: int = 30,
386
+ **kwargs):
387
+ super().__init__(**kwargs)
388
+
389
+ self.dino_model = dino_model
390
+ self.output_hidden_states = output_hidden_states
391
+ self.rescale_factor = dino_preprocess.rescale_factor
392
+
393
+ # Un-normalize from [-1.0, 1.0] (SD output) to [0, 1].
394
+ self.preprocess = transforms.Compose(
395
+ [
396
+ transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
397
+ transforms.Resize(size=256),
398
+ transforms.CenterCrop(size=(224, 224)),
399
+ transforms.Normalize(mean=dino_preprocess.image_mean, std=dino_preprocess.image_std)
400
+ ]
401
+ )
402
+
403
+ self.student_temp = student_temp
404
+ self.teacher_temp = teacher_temp
405
+ self.center_momentum = center_momentum
406
+ self.center = torch.zeros(1, 257, 1024).to(self.accelerator.device, dtype=self.dtype)
407
+
408
+ # TODO: add temp, now fixed to 0.04
409
+ # we apply a warm up for the teacher temperature because
410
+ # too high a temperature makes training unstable at the beginning
411
+ # self.teacher_temp_schedule = np.concatenate((
412
+ # np.linspace(warmup_teacher_temp,
413
+ # teacher_temp, warmup_teacher_temp_epochs),
414
+ # np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
415
+ # ))
416
+
417
+ self.dino_model = self.accelerator.prepare(self.dino_model)
418
+
419
+ def forward(
420
+ self,
421
+ target: torch.Tensor,
422
+ predict: torch.Tensor,
423
+ weights: torch.Tensor = None,
424
+ **kwargs) -> torch.Tensor:
425
+
426
+ predict = self.preprocess(predict)
427
+ target = self.preprocess(target)
428
+
429
+ encoder_input = torch.cat([target, predict]).to(self.dino_model.device, dtype=self.dino_model.dtype)
430
+
431
+ if self.output_hidden_states:
432
+ raise ValueError("Output hidden states not supported for DINO loss.")
433
+ image_enc_hidden_states = self.dino_model(encoder_input, output_hidden_states=True).hidden_states[-2]
434
+ else:
435
+ image_enc_hidden_states = self.dino_model(encoder_input).last_hidden_state
436
+
437
+ teacher_output, student_output = image_enc_hidden_states.chunk(2, dim=0) # [B, 257, 1024]
438
+
439
+ student_out = student_output.float() / self.student_temp
440
+
441
+ # teacher centering and sharpening
442
+ # temp = self.teacher_temp_schedule[epoch]
443
+ temp = self.teacher_temp
444
+ teacher_out = F.softmax((teacher_output.float() - self.center) / temp, dim=-1)
445
+ teacher_out = teacher_out.detach()
446
+
447
+ loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1, keepdim=True)
448
+ # self.update_center(teacher_output)
449
+
450
+ if weights is not None:
451
+ loss = loss * weights
452
+ return loss.mean()
453
+ return loss.mean()
454
+
455
+ @torch.no_grad()
456
+ def update_center(self, teacher_output):
457
+ """
458
+ Update center used for teacher output.
459
+ """
460
+ batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
461
+ self.accelerator.reduce(batch_center, reduction="sum")
462
+ batch_center = batch_center / (len(teacher_output) * self.accelerator.num_processes)
463
+
464
+ # ema update
465
+ self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
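A condensed restatement of the centering-and-sharpening objective used in DINOLoss.forward, with toy tensors (shapes follow the [B, 257, 1024] hidden states above; values are random placeholders):

    import torch
    import torch.nn.functional as F

    teacher_output = torch.randn(2, 257, 1024)   # frozen target-image features
    student_output = torch.randn(2, 257, 1024)   # features of the predicted image
    center = torch.zeros(1, 257, 1024)
    student_temp, teacher_temp = 0.1, 0.04

    teacher_out = F.softmax((teacher_output - center) / teacher_temp, dim=-1).detach()
    student_out = F.log_softmax(student_output / student_temp, dim=-1)
    loss = torch.sum(-teacher_out * student_out, dim=-1).mean()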
module/aggregator.py ADDED
@@ -0,0 +1,983 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin
10
+ from diffusers.utils import BaseOutput, logging
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
19
+ from diffusers.models.modeling_utils import ModelMixin
20
+ from diffusers.models.unets.unet_2d_blocks import (
21
+ CrossAttnDownBlock2D,
22
+ DownBlock2D,
23
+ UNetMidBlock2D,
24
+ UNetMidBlock2DCrossAttn,
25
+ get_down_block,
26
+ )
27
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ class ZeroConv(nn.Module):
34
+ def __init__(self, label_nc, norm_nc, mask=False):
35
+ super().__init__()
36
+ self.zero_conv = zero_module(nn.Conv2d(label_nc+norm_nc, norm_nc, 1, 1, 0))
37
+ self.mask = mask
38
+
39
+ def forward(self, hidden_states, h_ori=None):
40
+ # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
41
+ c, h = hidden_states
42
+ if not self.mask:
43
+ h = self.zero_conv(torch.cat([c, h], dim=1))
44
+ else:
45
+ h = self.zero_conv(torch.cat([c, h], dim=1)) * torch.zeros_like(h)
46
+ if h_ori is not None:
47
+ h = torch.cat([h_ori, h], dim=1)
48
+ return h
49
+
50
+
51
+ class SFT(nn.Module):
52
+ def __init__(self, label_nc, norm_nc, mask=False):
53
+ super().__init__()
54
+
55
+ # param_free_norm_type = str(parsed.group(1))
56
+ ks = 3
57
+ pw = ks // 2
58
+
59
+ self.mask = mask
60
+
61
+ nhidden = 128
62
+
63
+ self.mlp_shared = nn.Sequential(
64
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
65
+ nn.SiLU()
66
+ )
67
+ self.mul = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
68
+ self.add = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
69
+
70
+ def forward(self, hidden_states, mask=False):
71
+
72
+ c, h = hidden_states
73
+ mask = mask or self.mask
74
+ assert mask is False
75
+
76
+ actv = self.mlp_shared(c)
77
+ gamma = self.mul(actv)
78
+ beta = self.add(actv)
79
+
80
+ if self.mask:
81
+ gamma = gamma * torch.zeros_like(gamma)
82
+ beta = beta * torch.zeros_like(beta)
83
+ # gamma_ori, gamma_res = torch.split(gamma, [h_ori_c, h_c], dim=1)
84
+ # beta_ori, beta_res = torch.split(beta, [h_ori_c, h_c], dim=1)
85
+ # print(gamma_ori.mean(), gamma_res.mean(), beta_ori.mean(), beta_res.mean())
86
+ h = h * (gamma + 1) + beta
87
+ # sample_ori, sample_res = torch.split(h, [h_ori_c, h_c], dim=1)
88
+ # print(sample_ori.mean(), sample_res.mean())
89
+
90
+ return h
91
+
92
+
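A toy check of the SFT modulation above, assuming the condition features and hidden states share spatial size; gamma and beta are predicted from the condition and applied as h * (gamma + 1) + beta:

    import torch

    sft = SFT(label_nc=320, norm_nc=320)
    c = torch.randn(1, 320, 64, 64)   # condition features
    h = torch.randn(1, 320, 64, 64)   # hidden states to modulate
    out = sft((c, h))                 # same shape as h: torch.Size([1, 320, 64, 64])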
93
+ @dataclass
94
+ class AggregatorOutput(BaseOutput):
95
+ """
96
+ The output of [`Aggregator`].
97
+
98
+ Args:
99
+ down_block_res_samples (`tuple[torch.Tensor]`):
100
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
101
+ be of shape `(batch_size, channel * resolution, height // resolution, width // resolution)`. Output can be
102
+ used to condition the original UNet's downsampling activations.
103
+ mid_block_res_sample (`torch.Tensor`):
104
+ The activation of the middle block (the lowest sample resolution). The tensor should be of shape
105
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
106
+ Output can be used to condition the original UNet's middle block activation.
107
+ """
108
+
109
+ down_block_res_samples: Tuple[torch.Tensor]
110
+ mid_block_res_sample: torch.Tensor
111
+
112
+
113
+ class ConditioningEmbedding(nn.Module):
114
+ """
115
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
116
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
117
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
118
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
119
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
120
+ model) to encode image-space conditions ... into feature maps ..."
121
+ """
122
+
123
+ def __init__(
124
+ self,
125
+ conditioning_embedding_channels: int,
126
+ conditioning_channels: int = 3,
127
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
128
+ ):
129
+ super().__init__()
130
+
131
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
132
+
133
+ self.blocks = nn.ModuleList([])
134
+
135
+ for i in range(len(block_out_channels) - 1):
136
+ channel_in = block_out_channels[i]
137
+ channel_out = block_out_channels[i + 1]
138
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
139
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
140
+
141
+ self.conv_out = zero_module(
142
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
143
+ )
144
+
145
+ def forward(self, conditioning):
146
+ embedding = self.conv_in(conditioning)
147
+ embedding = F.silu(embedding)
148
+
149
+ for block in self.blocks:
150
+ embedding = block(embedding)
151
+ embedding = F.silu(embedding)
152
+
153
+ embedding = self.conv_out(embedding)
154
+
155
+ return embedding
156
+
157
+
158
+ class Aggregator(ModelMixin, ConfigMixin, FromOriginalModelMixin):
159
+ """
160
+ Aggregator model.
161
+
162
+ Args:
163
+ in_channels (`int`, defaults to 4):
164
+ The number of channels in the input sample.
165
+ flip_sin_to_cos (`bool`, defaults to `True`):
166
+ Whether to flip the sin to cos in the time embedding.
167
+ freq_shift (`int`, defaults to 0):
168
+ The frequency shift to apply to the time embedding.
169
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
170
+ The tuple of downsample blocks to use.
171
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
172
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
173
+ The tuple of output channels for each block.
174
+ layers_per_block (`int`, defaults to 2):
175
+ The number of layers per block.
176
+ downsample_padding (`int`, defaults to 1):
177
+ The padding to use for the downsampling convolution.
178
+ mid_block_scale_factor (`float`, defaults to 1):
179
+ The scale factor to use for the mid block.
180
+ act_fn (`str`, defaults to "silu"):
181
+ The activation function to use.
182
+ norm_num_groups (`int`, *optional*, defaults to 32):
183
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
184
+ in post-processing.
185
+ norm_eps (`float`, defaults to 1e-5):
186
+ The epsilon to use for the normalization.
187
+ cross_attention_dim (`int`, defaults to 1280):
188
+ The dimension of the cross attention features.
189
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
190
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
191
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
192
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
193
+ encoder_hid_dim (`int`, *optional*, defaults to None):
194
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
195
+ dimension to `cross_attention_dim`.
196
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
197
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
198
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
199
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
200
+ The dimension of the attention heads.
201
+ use_linear_projection (`bool`, defaults to `False`):
202
+ class_embed_type (`str`, *optional*, defaults to `None`):
203
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
204
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
205
+ addition_embed_type (`str`, *optional*, defaults to `None`):
206
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
207
+ "text". "text" will use the `TextTimeEmbedding` layer.
208
+ num_class_embeds (`int`, *optional*, defaults to 0):
209
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
210
+ class conditioning with `class_embed_type` equal to `None`.
211
+ upcast_attention (`bool`, defaults to `False`):
212
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
213
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
214
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
215
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
216
+ `class_embed_type="projection"`.
217
+ controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
218
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
219
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
220
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
221
+ global_pool_conditions (`bool`, defaults to `False`):
222
+ TODO(Patrick) - unused parameter.
223
+ addition_embed_type_num_heads (`int`, defaults to 64):
224
+ The number of heads to use for the `TextTimeEmbedding` layer.
225
+ """
226
+
227
+ _supports_gradient_checkpointing = True
228
+
229
+ @register_to_config
230
+ def __init__(
231
+ self,
232
+ in_channels: int = 4,
233
+ conditioning_channels: int = 3,
234
+ flip_sin_to_cos: bool = True,
235
+ freq_shift: int = 0,
236
+ down_block_types: Tuple[str, ...] = (
237
+ "CrossAttnDownBlock2D",
238
+ "CrossAttnDownBlock2D",
239
+ "CrossAttnDownBlock2D",
240
+ "DownBlock2D",
241
+ ),
242
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
243
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
244
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
245
+ layers_per_block: int = 2,
246
+ downsample_padding: int = 1,
247
+ mid_block_scale_factor: float = 1,
248
+ act_fn: str = "silu",
249
+ norm_num_groups: Optional[int] = 32,
250
+ norm_eps: float = 1e-5,
251
+ cross_attention_dim: int = 1280,
252
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
253
+ encoder_hid_dim: Optional[int] = None,
254
+ encoder_hid_dim_type: Optional[str] = None,
255
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
256
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
257
+ use_linear_projection: bool = False,
258
+ class_embed_type: Optional[str] = None,
259
+ addition_embed_type: Optional[str] = None,
260
+ addition_time_embed_dim: Optional[int] = None,
261
+ num_class_embeds: Optional[int] = None,
262
+ upcast_attention: bool = False,
263
+ resnet_time_scale_shift: str = "default",
264
+ projection_class_embeddings_input_dim: Optional[int] = None,
265
+ controlnet_conditioning_channel_order: str = "rgb",
266
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
267
+ global_pool_conditions: bool = False,
268
+ addition_embed_type_num_heads: int = 64,
269
+ pad_concat: bool = False,
270
+ ):
271
+ super().__init__()
272
+
273
+ # If `num_attention_heads` is not defined (which is the case for most models)
274
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
275
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
276
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
277
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
278
+ # which is why we correct for the naming here.
279
+ num_attention_heads = num_attention_heads or attention_head_dim
280
+ self.pad_concat = pad_concat
281
+
282
+ # Check inputs
283
+ if len(block_out_channels) != len(down_block_types):
284
+ raise ValueError(
285
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
286
+ )
287
+
288
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
289
+ raise ValueError(
290
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
291
+ )
292
+
293
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
294
+ raise ValueError(
295
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
296
+ )
297
+
298
+ if isinstance(transformer_layers_per_block, int):
299
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
300
+
301
+ # input
302
+ conv_in_kernel = 3
303
+ conv_in_padding = (conv_in_kernel - 1) // 2
304
+ self.conv_in = nn.Conv2d(
305
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
306
+ )
307
+
308
+ # time
309
+ time_embed_dim = block_out_channels[0] * 4
310
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
311
+ timestep_input_dim = block_out_channels[0]
312
+ self.time_embedding = TimestepEmbedding(
313
+ timestep_input_dim,
314
+ time_embed_dim,
315
+ act_fn=act_fn,
316
+ )
317
+
318
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
319
+ encoder_hid_dim_type = "text_proj"
320
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
321
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
322
+
323
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
324
+ raise ValueError(
325
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
326
+ )
327
+
328
+ if encoder_hid_dim_type == "text_proj":
329
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
330
+ elif encoder_hid_dim_type == "text_image_proj":
331
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
332
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
333
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
334
+ self.encoder_hid_proj = TextImageProjection(
335
+ text_embed_dim=encoder_hid_dim,
336
+ image_embed_dim=cross_attention_dim,
337
+ cross_attention_dim=cross_attention_dim,
338
+ )
339
+
340
+ elif encoder_hid_dim_type is not None:
341
+ raise ValueError(
342
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
343
+ )
344
+ else:
345
+ self.encoder_hid_proj = None
346
+
347
+ # class embedding
348
+ if class_embed_type is None and num_class_embeds is not None:
349
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
350
+ elif class_embed_type == "timestep":
351
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
352
+ elif class_embed_type == "identity":
353
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
354
+ elif class_embed_type == "projection":
355
+ if projection_class_embeddings_input_dim is None:
356
+ raise ValueError(
357
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
358
+ )
359
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
360
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
361
+ # 2. it projects from an arbitrary input dimension.
362
+ #
363
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
364
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
365
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
366
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
367
+ else:
368
+ self.class_embedding = None
369
+
370
+ if addition_embed_type == "text":
371
+ if encoder_hid_dim is not None:
372
+ text_time_embedding_from_dim = encoder_hid_dim
373
+ else:
374
+ text_time_embedding_from_dim = cross_attention_dim
375
+
376
+ self.add_embedding = TextTimeEmbedding(
377
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
378
+ )
379
+ elif addition_embed_type == "text_image":
380
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
381
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
382
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
383
+ self.add_embedding = TextImageTimeEmbedding(
384
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
385
+ )
386
+ elif addition_embed_type == "text_time":
387
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
388
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
389
+
390
+ elif addition_embed_type is not None:
391
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
392
+
393
+ # control net conditioning embedding
394
+ self.ref_conv_in = nn.Conv2d(
395
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
396
+ )
397
+
398
+ self.down_blocks = nn.ModuleList([])
399
+ self.controlnet_down_blocks = nn.ModuleList([])
400
+
401
+ if isinstance(only_cross_attention, bool):
402
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
403
+
404
+ if isinstance(attention_head_dim, int):
405
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
406
+
407
+ if isinstance(num_attention_heads, int):
408
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
409
+
410
+ # down
411
+ output_channel = block_out_channels[0]
412
+
413
+ # controlnet_block = ZeroConv(output_channel, output_channel)
414
+ controlnet_block = nn.Sequential(
415
+ SFT(output_channel, output_channel),
416
+ zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1))
417
+ )
418
+ self.controlnet_down_blocks.append(controlnet_block)
419
+
420
+ for i, down_block_type in enumerate(down_block_types):
421
+ input_channel = output_channel
422
+ output_channel = block_out_channels[i]
423
+ is_final_block = i == len(block_out_channels) - 1
424
+
425
+ down_block = get_down_block(
426
+ down_block_type,
427
+ num_layers=layers_per_block,
428
+ transformer_layers_per_block=transformer_layers_per_block[i],
429
+ in_channels=input_channel,
430
+ out_channels=output_channel,
431
+ temb_channels=time_embed_dim,
432
+ add_downsample=not is_final_block,
433
+ resnet_eps=norm_eps,
434
+ resnet_act_fn=act_fn,
435
+ resnet_groups=norm_num_groups,
436
+ cross_attention_dim=cross_attention_dim,
437
+ num_attention_heads=num_attention_heads[i],
438
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
439
+ downsample_padding=downsample_padding,
440
+ use_linear_projection=use_linear_projection,
441
+ only_cross_attention=only_cross_attention[i],
442
+ upcast_attention=upcast_attention,
443
+ resnet_time_scale_shift=resnet_time_scale_shift,
444
+ )
445
+ self.down_blocks.append(down_block)
446
+
447
+ for _ in range(layers_per_block):
448
+ # controlnet_block = ZeroConv(output_channel, output_channel)
449
+ controlnet_block = nn.Sequential(
450
+ SFT(output_channel, output_channel),
451
+ zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1))
452
+ )
453
+ self.controlnet_down_blocks.append(controlnet_block)
454
+
455
+ if not is_final_block:
456
+ # controlnet_block = ZeroConv(output_channel, output_channel)
457
+ controlnet_block = nn.Sequential(
458
+ SFT(output_channel, output_channel),
459
+ zero_module(nn.Conv2d(output_channel, output_channel, kernel_size=1))
460
+ )
461
+ self.controlnet_down_blocks.append(controlnet_block)
462
+
463
+ # mid
464
+ mid_block_channel = block_out_channels[-1]
465
+
466
+ # controlnet_block = ZeroConv(mid_block_channel, mid_block_channel)
467
+ controlnet_block = nn.Sequential(
468
+ SFT(mid_block_channel, mid_block_channel),
469
+ zero_module(nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1))
470
+ )
471
+ self.controlnet_mid_block = controlnet_block
472
+
473
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
474
+ self.mid_block = UNetMidBlock2DCrossAttn(
475
+ transformer_layers_per_block=transformer_layers_per_block[-1],
476
+ in_channels=mid_block_channel,
477
+ temb_channels=time_embed_dim,
478
+ resnet_eps=norm_eps,
479
+ resnet_act_fn=act_fn,
480
+ output_scale_factor=mid_block_scale_factor,
481
+ resnet_time_scale_shift=resnet_time_scale_shift,
482
+ cross_attention_dim=cross_attention_dim,
483
+ num_attention_heads=num_attention_heads[-1],
484
+ resnet_groups=norm_num_groups,
485
+ use_linear_projection=use_linear_projection,
486
+ upcast_attention=upcast_attention,
487
+ )
488
+ elif mid_block_type == "UNetMidBlock2D":
489
+ self.mid_block = UNetMidBlock2D(
490
+ in_channels=block_out_channels[-1],
491
+ temb_channels=time_embed_dim,
492
+ num_layers=0,
493
+ resnet_eps=norm_eps,
494
+ resnet_act_fn=act_fn,
495
+ output_scale_factor=mid_block_scale_factor,
496
+ resnet_groups=norm_num_groups,
497
+ resnet_time_scale_shift=resnet_time_scale_shift,
498
+ add_attention=False,
499
+ )
500
+ else:
501
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
502
+
503
+ @classmethod
504
+ def from_unet(
505
+ cls,
506
+ unet: UNet2DConditionModel,
507
+ controlnet_conditioning_channel_order: str = "rgb",
508
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
509
+ load_weights_from_unet: bool = True,
510
+ conditioning_channels: int = 3,
511
+ ):
512
+ r"""
513
+ Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`].
514
+
515
+ Parameters:
516
+ unet (`UNet2DConditionModel`):
517
+ The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied
518
+ where applicable.
519
+ """
520
+ transformer_layers_per_block = (
521
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
522
+ )
523
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
524
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
525
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
526
+ addition_time_embed_dim = (
527
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
528
+ )
529
+
530
+ controlnet = cls(
531
+ encoder_hid_dim=encoder_hid_dim,
532
+ encoder_hid_dim_type=encoder_hid_dim_type,
533
+ addition_embed_type=addition_embed_type,
534
+ addition_time_embed_dim=addition_time_embed_dim,
535
+ transformer_layers_per_block=transformer_layers_per_block,
536
+ in_channels=unet.config.in_channels,
537
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
538
+ freq_shift=unet.config.freq_shift,
539
+ down_block_types=unet.config.down_block_types,
540
+ only_cross_attention=unet.config.only_cross_attention,
541
+ block_out_channels=unet.config.block_out_channels,
542
+ layers_per_block=unet.config.layers_per_block,
543
+ downsample_padding=unet.config.downsample_padding,
544
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
545
+ act_fn=unet.config.act_fn,
546
+ norm_num_groups=unet.config.norm_num_groups,
547
+ norm_eps=unet.config.norm_eps,
548
+ cross_attention_dim=unet.config.cross_attention_dim,
549
+ attention_head_dim=unet.config.attention_head_dim,
550
+ num_attention_heads=unet.config.num_attention_heads,
551
+ use_linear_projection=unet.config.use_linear_projection,
552
+ class_embed_type=unet.config.class_embed_type,
553
+ num_class_embeds=unet.config.num_class_embeds,
554
+ upcast_attention=unet.config.upcast_attention,
555
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
556
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
557
+ mid_block_type=unet.config.mid_block_type,
558
+ controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
559
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
560
+ conditioning_channels=conditioning_channels,
561
+ )
562
+
563
+ if load_weights_from_unet:
564
+ controlnet.conv_in.load_state_dict(unet.conv_in.state_dict())
565
+ controlnet.ref_conv_in.load_state_dict(unet.conv_in.state_dict())
566
+ controlnet.time_proj.load_state_dict(unet.time_proj.state_dict())
567
+ controlnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
568
+
569
+ if controlnet.class_embedding:
570
+ controlnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
571
+
572
+ if hasattr(controlnet, "add_embedding"):
573
+ controlnet.add_embedding.load_state_dict(unet.add_embedding.state_dict())
574
+
575
+ controlnet.down_blocks.load_state_dict(unet.down_blocks.state_dict())
576
+ controlnet.mid_block.load_state_dict(unet.mid_block.state_dict())
577
+
578
+ return controlnet
579
+
580
+ @property
581
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
582
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
583
+ r"""
584
+ Returns:
585
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
586
+ indexed by its weight name.
587
+ """
588
+ # set recursively
589
+ processors = {}
590
+
591
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
592
+ if hasattr(module, "get_processor"):
593
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
594
+
595
+ for sub_name, child in module.named_children():
596
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
597
+
598
+ return processors
599
+
600
+ for name, module in self.named_children():
601
+ fn_recursive_add_processors(name, module, processors)
602
+
603
+ return processors
604
+
605
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
606
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
607
+ r"""
608
+ Sets the attention processor to use to compute attention.
609
+
610
+ Parameters:
611
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
612
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
613
+ for **all** `Attention` layers.
614
+
615
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
616
+ processor. This is strongly recommended when setting trainable attention processors.
617
+
618
+ """
619
+ count = len(self.attn_processors.keys())
620
+
621
+ if isinstance(processor, dict) and len(processor) != count:
622
+ raise ValueError(
623
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
624
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
625
+ )
626
+
627
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
628
+ if hasattr(module, "set_processor"):
629
+ if not isinstance(processor, dict):
630
+ module.set_processor(processor)
631
+ else:
632
+ module.set_processor(processor.pop(f"{name}.processor"))
633
+
634
+ for sub_name, child in module.named_children():
635
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
636
+
637
+ for name, module in self.named_children():
638
+ fn_recursive_attn_processor(name, module, processor)
639
+
640
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
641
+ def set_default_attn_processor(self):
642
+ """
643
+ Disables custom attention processors and sets the default attention implementation.
644
+ """
645
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
646
+ processor = AttnAddedKVProcessor()
647
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
648
+ processor = AttnProcessor()
649
+ else:
650
+ raise ValueError(
651
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
652
+ )
653
+
654
+ self.set_attn_processor(processor)
655
+
656
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
657
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
658
+ r"""
659
+ Enable sliced attention computation.
660
+
661
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
662
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
663
+
664
+ Args:
665
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
666
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
667
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
668
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
669
+ must be a multiple of `slice_size`.
670
+ """
671
+ sliceable_head_dims = []
672
+
673
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
674
+ if hasattr(module, "set_attention_slice"):
675
+ sliceable_head_dims.append(module.sliceable_head_dim)
676
+
677
+ for child in module.children():
678
+ fn_recursive_retrieve_sliceable_dims(child)
679
+
680
+ # retrieve number of attention layers
681
+ for module in self.children():
682
+ fn_recursive_retrieve_sliceable_dims(module)
683
+
684
+ num_sliceable_layers = len(sliceable_head_dims)
685
+
686
+ if slice_size == "auto":
687
+ # half the attention head size is usually a good trade-off between
688
+ # speed and memory
689
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
690
+ elif slice_size == "max":
691
+ # make smallest slice possible
692
+ slice_size = num_sliceable_layers * [1]
693
+
694
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
695
+
696
+ if len(slice_size) != len(sliceable_head_dims):
697
+ raise ValueError(
698
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
699
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
700
+ )
701
+
702
+ for i in range(len(slice_size)):
703
+ size = slice_size[i]
704
+ dim = sliceable_head_dims[i]
705
+ if size is not None and size > dim:
706
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
707
+
708
+ # Recursively walk through all the children.
709
+ # Any children which exposes the set_attention_slice method
710
+ # gets the message
711
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
712
+ if hasattr(module, "set_attention_slice"):
713
+ module.set_attention_slice(slice_size.pop())
714
+
715
+ for child in module.children():
716
+ fn_recursive_set_attention_slice(child, slice_size)
717
+
718
+ reversed_slice_size = list(reversed(slice_size))
719
+ for module in self.children():
720
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
721
+
722
+ def process_encoder_hidden_states(
723
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
724
+ ) -> torch.Tensor:
725
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
726
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
727
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
728
+ # Kandinsky 2.1 - style
729
+ if "image_embeds" not in added_cond_kwargs:
730
+ raise ValueError(
731
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
732
+ )
733
+
734
+ image_embeds = added_cond_kwargs.get("image_embeds")
735
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
736
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
737
+ # Kandinsky 2.2 - style
738
+ if "image_embeds" not in added_cond_kwargs:
739
+ raise ValueError(
740
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
741
+ )
742
+ image_embeds = added_cond_kwargs.get("image_embeds")
743
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
744
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
745
+ if "image_embeds" not in added_cond_kwargs:
746
+ raise ValueError(
747
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
748
+ )
749
+ image_embeds = added_cond_kwargs.get("image_embeds")
750
+ image_embeds = self.encoder_hid_proj(image_embeds)
751
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
752
+ return encoder_hidden_states
753
+
754
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
755
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
756
+ module.gradient_checkpointing = value
757
+
758
+ def forward(
759
+ self,
760
+ sample: torch.FloatTensor,
761
+ timestep: Union[torch.Tensor, float, int],
762
+ encoder_hidden_states: torch.Tensor,
763
+ controlnet_cond: torch.FloatTensor,
764
+ cat_dim: int = -2,
765
+ conditioning_scale: float = 1.0,
766
+ class_labels: Optional[torch.Tensor] = None,
767
+ timestep_cond: Optional[torch.Tensor] = None,
768
+ attention_mask: Optional[torch.Tensor] = None,
769
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
770
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
771
+ return_dict: bool = True,
772
+ ) -> Union[AggregatorOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
773
+ """
774
+ The [`Aggregator`] forward method.
775
+
776
+ Args:
777
+ sample (`torch.FloatTensor`):
778
+ The noisy input tensor.
779
+ timestep (`Union[torch.Tensor, float, int]`):
780
+ The number of timesteps to denoise an input.
781
+ encoder_hidden_states (`torch.Tensor`):
782
+ The encoder hidden states.
783
+ controlnet_cond (`torch.FloatTensor`):
784
+ The reference conditioning tensor of shape `(batch_size, channels, height, width)`.
785
+ conditioning_scale (`float`, defaults to `1.0`):
786
+ The scale factor for ControlNet outputs.
787
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
788
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
789
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
790
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
791
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
792
+ embeddings.
793
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
794
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
795
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
796
+ negative values to the attention scores corresponding to "discard" tokens.
797
+ added_cond_kwargs (`dict`):
798
+ Additional conditions for the Stable Diffusion XL UNet.
799
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
800
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
801
+ return_dict (`bool`, defaults to `True`):
802
+ Whether or not to return an [`AggregatorOutput`] instead of a plain tuple.
803
+
804
+ Returns:
805
+ [`AggregatorOutput`] **or** `tuple`:
806
+ If `return_dict` is `True`, an [`AggregatorOutput`] is returned, otherwise a tuple is
807
+ returned where the first element is the sample tensor.
808
+ """
809
+ # check channel order
810
+ channel_order = self.config.controlnet_conditioning_channel_order
811
+
812
+ if channel_order == "rgb":
813
+ # in rgb order by default
814
+ ...
815
+ else:
816
+ raise ValueError(f"unknown `controlnet_conditioning_channel_order`: {channel_order}")
817
+
818
+ # prepare attention_mask
819
+ if attention_mask is not None:
820
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
821
+ attention_mask = attention_mask.unsqueeze(1)
822
+
823
+ # 1. time
824
+ timesteps = timestep
825
+ if not torch.is_tensor(timesteps):
826
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
827
+ # This would be a good case for the `match` statement (Python 3.10+)
828
+ is_mps = sample.device.type == "mps"
829
+ if isinstance(timestep, float):
830
+ dtype = torch.float32 if is_mps else torch.float64
831
+ else:
832
+ dtype = torch.int32 if is_mps else torch.int64
833
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
834
+ elif len(timesteps.shape) == 0:
835
+ timesteps = timesteps[None].to(sample.device)
836
+
837
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
838
+ timesteps = timesteps.expand(sample.shape[0])
839
+
840
+ t_emb = self.time_proj(timesteps)
841
+
842
+ # timesteps does not contain any weights and will always return f32 tensors
843
+ # but time_embedding might actually be running in fp16. so we need to cast here.
844
+ # there might be better ways to encapsulate this.
845
+ t_emb = t_emb.to(dtype=sample.dtype)
846
+
847
+ emb = self.time_embedding(t_emb, timestep_cond)
848
+ aug_emb = None
849
+
850
+ if self.class_embedding is not None:
851
+ if class_labels is None:
852
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
853
+
854
+ if self.config.class_embed_type == "timestep":
855
+ class_labels = self.time_proj(class_labels)
856
+
857
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
858
+ emb = emb + class_emb
859
+
860
+ if self.config.addition_embed_type is not None:
861
+ if self.config.addition_embed_type == "text":
862
+ aug_emb = self.add_embedding(encoder_hidden_states)
863
+
864
+ elif self.config.addition_embed_type == "text_time":
865
+ if "text_embeds" not in added_cond_kwargs:
866
+ raise ValueError(
867
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
868
+ )
869
+ text_embeds = added_cond_kwargs.get("text_embeds")
870
+ if "time_ids" not in added_cond_kwargs:
871
+ raise ValueError(
872
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
873
+ )
874
+ time_ids = added_cond_kwargs.get("time_ids")
875
+ time_embeds = self.add_time_proj(time_ids.flatten())
876
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
877
+
878
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
879
+ add_embeds = add_embeds.to(emb.dtype)
880
+ aug_emb = self.add_embedding(add_embeds)
881
+
882
+ emb = emb + aug_emb if aug_emb is not None else emb
883
+
884
+ encoder_hidden_states = self.process_encoder_hidden_states(
885
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
886
+ )
887
+
888
+ # 2. prepare input
889
+ cond_latent = self.conv_in(sample)
890
+ ref_latent = self.ref_conv_in(controlnet_cond)
891
+ batch_size, channel, height, width = cond_latent.shape
892
+ if self.pad_concat:
893
+ if cat_dim == -2 or cat_dim == 2:
894
+ concat_pad = torch.zeros(batch_size, channel, 1, width)
895
+ elif cat_dim == -1 or cat_dim == 3:
896
+ concat_pad = torch.zeros(batch_size, channel, height, 1)
897
+ else:
898
+ raise ValueError(f"Aggregator must concatenate along a spatial dimension, but was asked to concat along dim: {cat_dim}.")
899
+ concat_pad = concat_pad.to(cond_latent.device, dtype=cond_latent.dtype)
900
+ sample = torch.cat([cond_latent, concat_pad, ref_latent], dim=cat_dim)
901
+ else:
902
+ sample = torch.cat([cond_latent, ref_latent], dim=cat_dim)
903
+
904
+ # 3. down
905
+ down_block_res_samples = (sample,)
906
+ for downsample_block in self.down_blocks:
907
+ sample, res_samples = downsample_block(
908
+ hidden_states=sample,
909
+ temb=emb,
910
+ cross_attention_kwargs=cross_attention_kwargs,
911
+ )
912
+
913
+ # rebuild sample: split and concat
914
+ if self.pad_concat:
915
+ batch_size, channel, height, width = sample.shape
916
+ if cat_dim == -2 or cat_dim == 2:
917
+ cond_latent = sample[:, :, :height//2, :]
918
+ ref_latent = sample[:, :, -(height//2):, :]
919
+ concat_pad = torch.zeros(batch_size, channel, 1, width)
920
+ elif cat_dim == -1 or cat_dim == 3:
921
+ cond_latent = sample[:, :, :, :width//2]
922
+ ref_latent = sample[:, :, :, -(width//2):]
923
+ concat_pad = torch.zeros(batch_size, channel, height, 1)
924
+ concat_pad = concat_pad.to(cond_latent.device, dtype=cond_latent.dtype)
925
+ sample = torch.cat([cond_latent, concat_pad, ref_latent], dim=cat_dim)
926
+ res_samples = res_samples[:-1] + (sample,)
927
+
928
+ down_block_res_samples += res_samples
929
+
930
+ # 4. mid
931
+ if self.mid_block is not None:
932
+ sample = self.mid_block(
933
+ sample,
934
+ emb,
935
+ cross_attention_kwargs=cross_attention_kwargs,
936
+ )
937
+
938
+ # 5. split samples and SFT.
939
+ controlnet_down_block_res_samples = ()
940
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
941
+ batch_size, channel, height, width = down_block_res_sample.shape
942
+ if cat_dim == -2 or cat_dim == 2:
943
+ cond_latent = down_block_res_sample[:, :, :height//2, :]
944
+ ref_latent = down_block_res_sample[:, :, -(height//2):, :]
945
+ elif cat_dim == -1 or cat_dim == 3:
946
+ cond_latent = down_block_res_sample[:, :, :, :width//2]
947
+ ref_latent = down_block_res_sample[:, :, :, -(width//2):]
948
+ down_block_res_sample = controlnet_block((cond_latent, ref_latent), )
949
+ controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)
950
+
951
+ down_block_res_samples = controlnet_down_block_res_samples
952
+
953
+ batch_size, channel, height, width = sample.shape
954
+ if cat_dim == -2 or cat_dim == 2:
955
+ cond_latent = sample[:, :, :height//2, :]
956
+ ref_latent = sample[:, :, -(height//2):, :]
957
+ elif cat_dim == -1 or cat_dim == 3:
958
+ cond_latent = sample[:, :, :, :width//2]
959
+ ref_latent = sample[:, :, :, -(width//2):]
960
+ mid_block_res_sample = self.controlnet_mid_block((cond_latent, ref_latent), )
961
+
962
+ # 6. scaling
963
+ down_block_res_samples = [sample*conditioning_scale for sample in down_block_res_samples]
964
+ mid_block_res_sample = mid_block_res_sample*conditioning_scale
965
+
966
+ if self.config.global_pool_conditions:
967
+ down_block_res_samples = [
968
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
969
+ ]
970
+ mid_block_res_sample = torch.mean(mid_block_res_sample, dim=(2, 3), keepdim=True)
971
+
972
+ if not return_dict:
973
+ return (down_block_res_samples, mid_block_res_sample)
974
+
975
+ return AggregatorOutput(
976
+ down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
977
+ )
978
+
979
+
980
+ def zero_module(module):
981
+ for p in module.parameters():
982
+ nn.init.zeros_(p)
983
+ return module
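Below is a minimal standalone sketch, not part of the uploaded module/aggregator.py, illustrating the concat/pad/split convention that `Aggregator.forward` applies to the restoration and reference latents before the ZeroSFT blocks split them back apart (tensor shapes are made up for illustration):

import torch

cond_latent = torch.randn(1, 4, 8, 8)   # hypothetical restoration latent
ref_latent = torch.randn(1, 4, 8, 8)    # hypothetical reference latent
cat_dim = -2                            # concatenate along height, the default above

# pad_concat path: a one-pixel zero strip is inserted between the two halves
pad = torch.zeros(1, 4, 1, 8, dtype=cond_latent.dtype)
sample = torch.cat([cond_latent, pad, ref_latent], dim=cat_dim)   # shape (1, 4, 17, 8)

# the split used before the SFT blocks: top and bottom `height // 2` rows
height = sample.shape[-2]
cond_back = sample[:, :, :height // 2, :]
ref_back = sample[:, :, -(height // 2):, :]
assert torch.equal(cond_back, cond_latent) and torch.equal(ref_back, ref_latent)

The one-pixel zero padding keeps the two halves separable by an integer split even after downsampling, which is why the same `height // 2` indexing recurs in the down-block and mid-block handling above.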
module/attention.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copied from diffusers.models.attention.py
2
+
3
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from typing import Any, Dict, Optional
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import nn
21
+
22
+ from diffusers.utils import deprecate, logging
23
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
24
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
25
+ from diffusers.models.attention_processor import Attention
26
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
27
+ from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
28
+
29
+ from module.min_sdxl import LoRACompatibleLinear, LoRALinearLayer
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+ def create_custom_forward(module):
35
+ def custom_forward(*inputs):
36
+ return module(*inputs)
37
+
38
+ return custom_forward
39
+
40
+ def maybe_grad_checkpoint(resnet, attn, hidden_states, temb, encoder_hidden_states, adapter_hidden_states, do_ckpt=True):
41
+
42
+ if do_ckpt:
43
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
44
+ hidden_states, extracted_kv = torch.utils.checkpoint.checkpoint(
45
+ create_custom_forward(attn), hidden_states, encoder_hidden_states, adapter_hidden_states, use_reentrant=False
46
+ )
47
+ else:
48
+ hidden_states = resnet(hidden_states, temb)
49
+ hidden_states, extracted_kv = attn(
50
+ hidden_states,
51
+ encoder_hidden_states=encoder_hidden_states,
52
+ adapter_hidden_states=adapter_hidden_states,
53
+ )
54
+ return hidden_states, extracted_kv
55
+
56
+
57
+ def init_lora_in_attn(attn_module, rank: int = 4, is_kvcopy=False):
58
+ # Set the `lora_layer` attribute of the attention-related matrices.
59
+
60
+ attn_module.to_k.set_lora_layer(
61
+ LoRALinearLayer(
62
+ in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=rank
63
+ )
64
+ )
65
+ attn_module.to_v.set_lora_layer(
66
+ LoRALinearLayer(
67
+ in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=rank
68
+ )
69
+ )
70
+
71
+ if not is_kvcopy:
72
+ attn_module.to_q.set_lora_layer(
73
+ LoRALinearLayer(
74
+ in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=rank
75
+ )
76
+ )
77
+
78
+ attn_module.to_out[0].set_lora_layer(
79
+ LoRALinearLayer(
80
+ in_features=attn_module.to_out[0].in_features,
81
+ out_features=attn_module.to_out[0].out_features,
82
+ rank=rank,
83
+ )
84
+ )
85
+
86
+ def drop_kvs(encoder_kvs, drop_chance):
87
+ for layer in encoder_kvs:
88
+ len_tokens = encoder_kvs[layer].self_attention.k.shape[1]
89
+ idx_to_keep = (torch.rand(len_tokens) > drop_chance)
90
+
91
+ encoder_kvs[layer].self_attention.k = encoder_kvs[layer].self_attention.k[:, idx_to_keep]
92
+ encoder_kvs[layer].self_attention.v = encoder_kvs[layer].self_attention.v[:, idx_to_keep]
93
+
94
+ return encoder_kvs
95
+
96
+ def clone_kvs(encoder_kvs):
97
+ cloned_kvs = {}
98
+ for layer in encoder_kvs:
99
+ sa_cpy = KVCache(k=encoder_kvs[layer].self_attention.k.clone(),
100
+ v=encoder_kvs[layer].self_attention.v.clone())
101
+
102
+ ca_cpy = KVCache(k=encoder_kvs[layer].cross_attention.k.clone(),
103
+ v=encoder_kvs[layer].cross_attention.v.clone())
104
+
105
+ cloned_layer_cache = AttentionCache(self_attention=sa_cpy, cross_attention=ca_cpy)
106
+
107
+ cloned_kvs[layer] = cloned_layer_cache
108
+
109
+ return cloned_kvs
110
+
111
+
112
+ class KVCache(object):
113
+ def __init__(self, k, v):
114
+ self.k = k
115
+ self.v = v
116
+
117
+ class AttentionCache(object):
118
+ def __init__(self, self_attention: KVCache, cross_attention: KVCache):
119
+ self.self_attention = self_attention
120
+ self.cross_attention = cross_attention
121
+
122
+ class KVCopy(nn.Module):
123
+ def __init__(
124
+ self, inner_dim, cross_attention_dim=None,
125
+ ):
126
+ super(KVCopy, self).__init__()
127
+
128
+ in_dim = cross_attention_dim or inner_dim
129
+
130
+ self.to_k = LoRACompatibleLinear(in_dim, inner_dim, bias=False)
131
+ self.to_v = LoRACompatibleLinear(in_dim, inner_dim, bias=False)
132
+
133
+ def forward(self, hidden_states):
134
+
135
+ k = self.to_k(hidden_states)
136
+ v = self.to_v(hidden_states)
137
+
138
+ return KVCache(k=k, v=v)
139
+
140
+ def init_kv_copy(self, source_attn):
141
+ with torch.no_grad():
142
+ self.to_k.weight.copy_(source_attn.to_k.weight)
143
+ self.to_v.weight.copy_(source_attn.to_v.weight)
144
+
145
+
146
+ class FeedForward(nn.Module):
147
+ r"""
148
+ A feed-forward layer.
149
+
150
+ Parameters:
151
+ dim (`int`): The number of channels in the input.
152
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
153
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
154
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
155
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
156
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
157
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
158
+ """
159
+
160
+ def __init__(
161
+ self,
162
+ dim: int,
163
+ dim_out: Optional[int] = None,
164
+ mult: int = 4,
165
+ dropout: float = 0.0,
166
+ activation_fn: str = "geglu",
167
+ final_dropout: bool = False,
168
+ inner_dim=None,
169
+ bias: bool = True,
170
+ ):
171
+ super().__init__()
172
+ if inner_dim is None:
173
+ inner_dim = int(dim * mult)
174
+ dim_out = dim_out if dim_out is not None else dim
175
+
176
+ if activation_fn == "gelu":
177
+ act_fn = GELU(dim, inner_dim, bias=bias)
178
+ elif activation_fn == "gelu-approximate":
179
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
180
+ elif activation_fn == "geglu":
181
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
182
+ elif activation_fn == "geglu-approximate":
183
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
184
+
185
+ self.net = nn.ModuleList([])
186
+ # project in
187
+ self.net.append(act_fn)
188
+ # project dropout
189
+ self.net.append(nn.Dropout(dropout))
190
+ # project out
191
+ self.net.append(nn.Linear(inner_dim, dim_out, bias=bias))
192
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
193
+ if final_dropout:
194
+ self.net.append(nn.Dropout(dropout))
195
+
196
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
197
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
198
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
199
+ deprecate("scale", "1.0.0", deprecation_message)
200
+ for module in self.net:
201
+ hidden_states = module(hidden_states)
202
+ return hidden_states
203
+
204
+
205
+ def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
206
+ # "feed_forward_chunk_size" can be used to save memory
207
+ if hidden_states.shape[chunk_dim] % chunk_size != 0:
208
+ raise ValueError(
209
+ f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
210
+ )
211
+
212
+ num_chunks = hidden_states.shape[chunk_dim] // chunk_size
213
+ ff_output = torch.cat(
214
+ [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
215
+ dim=chunk_dim,
216
+ )
217
+ return ff_output
218
+
219
+
220
+ @maybe_allow_in_graph
221
+ class GatedSelfAttentionDense(nn.Module):
222
+ r"""
223
+ A gated self-attention dense layer that combines visual features and object features.
224
+
225
+ Parameters:
226
+ query_dim (`int`): The number of channels in the query.
227
+ context_dim (`int`): The number of channels in the context.
228
+ n_heads (`int`): The number of heads to use for attention.
229
+ d_head (`int`): The number of channels in each head.
230
+ """
231
+
232
+ def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
233
+ super().__init__()
234
+
235
+ # we need a linear projection since we concatenate the visual features and object features
236
+ self.linear = nn.Linear(context_dim, query_dim)
237
+
238
+ self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
239
+ self.ff = FeedForward(query_dim, activation_fn="geglu")
240
+
241
+ self.norm1 = nn.LayerNorm(query_dim)
242
+ self.norm2 = nn.LayerNorm(query_dim)
243
+
244
+ self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
245
+ self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
246
+
247
+ self.enabled = True
248
+
249
+ def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
250
+ if not self.enabled:
251
+ return x
252
+
253
+ n_visual = x.shape[1]
254
+ objs = self.linear(objs)
255
+
256
+ x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
257
+ x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
258
+
259
+ return x
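As a reading aid, here is a minimal, self-contained sketch, not part of the uploaded module/attention.py, of what `_chunked_feed_forward` above computes, with a plain `nn.Linear` standing in for the `FeedForward` module:

import torch
from torch import nn

ff = nn.Linear(64, 64)                   # stand-in for the FeedForward module
hidden_states = torch.randn(2, 128, 64)  # (batch, tokens, dim)
chunk_dim, chunk_size = 1, 32            # chunk_size must divide the token count

# run the layer chunk by chunk along the token dimension, then re-concatenate
num_chunks = hidden_states.shape[chunk_dim] // chunk_size
chunked_out = torch.cat(
    [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
    dim=chunk_dim,
)
# same result as a single pass, but with a smaller peak activation footprint
assert torch.allclose(chunked_out, ff(hidden_states), atol=1e-6)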
module/diffusers_vae/autoencoder_kl.py ADDED
@@ -0,0 +1,489 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, Optional, Tuple, Union
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
20
+ from diffusers.loaders import FromOriginalVAEMixin
21
+ from diffusers.utils.accelerate_utils import apply_forward_hook
22
+ from diffusers.models.attention_processor import (
23
+ ADDED_KV_ATTENTION_PROCESSORS,
24
+ CROSS_ATTENTION_PROCESSORS,
25
+ Attention,
26
+ AttentionProcessor,
27
+ AttnAddedKVProcessor,
28
+ AttnProcessor,
29
+ )
30
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
33
+
34
+
35
+ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
36
+ r"""
37
+ A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
38
+
39
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
40
+ for all models (such as downloading or saving).
41
+
42
+ Parameters:
43
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
44
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
45
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
46
+ Tuple of downsample block types.
47
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
48
+ Tuple of upsample block types.
49
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
50
+ Tuple of block output channels.
51
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
52
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
53
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
54
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
55
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
56
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
57
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
58
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
59
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
60
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
61
+ force_upcast (`bool`, *optional*, default to `True`):
62
+ If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
63
+ can be fine-tuned / trained to a lower range without losing too much precision, in which case
64
+ `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
65
+ """
66
+
67
+ _supports_gradient_checkpointing = True
68
+
69
+ @register_to_config
70
+ def __init__(
71
+ self,
72
+ in_channels: int = 3,
73
+ out_channels: int = 3,
74
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
75
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
76
+ block_out_channels: Tuple[int] = (64,),
77
+ layers_per_block: int = 1,
78
+ act_fn: str = "silu",
79
+ latent_channels: int = 4,
80
+ norm_num_groups: int = 32,
81
+ sample_size: int = 32,
82
+ scaling_factor: float = 0.18215,
83
+ force_upcast: float = True,
84
+ ):
85
+ super().__init__()
86
+
87
+ # pass init params to Encoder
88
+ self.encoder = Encoder(
89
+ in_channels=in_channels,
90
+ out_channels=latent_channels,
91
+ down_block_types=down_block_types,
92
+ block_out_channels=block_out_channels,
93
+ layers_per_block=layers_per_block,
94
+ act_fn=act_fn,
95
+ norm_num_groups=norm_num_groups,
96
+ double_z=True,
97
+ )
98
+
99
+ # pass init params to Decoder
100
+ self.decoder = Decoder(
101
+ in_channels=latent_channels,
102
+ out_channels=out_channels,
103
+ up_block_types=up_block_types,
104
+ block_out_channels=block_out_channels,
105
+ layers_per_block=layers_per_block,
106
+ norm_num_groups=norm_num_groups,
107
+ act_fn=act_fn,
108
+ )
109
+
110
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
111
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
112
+
113
+ self.use_slicing = False
114
+ self.use_tiling = False
115
+
116
+ # only relevant if vae tiling is enabled
117
+ self.tile_sample_min_size = self.config.sample_size
118
+ sample_size = (
119
+ self.config.sample_size[0]
120
+ if isinstance(self.config.sample_size, (list, tuple))
121
+ else self.config.sample_size
122
+ )
123
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
124
+ self.tile_overlap_factor = 0.25
125
+
126
+ def _set_gradient_checkpointing(self, module, value=False):
127
+ if isinstance(module, (Encoder, Decoder)):
128
+ module.gradient_checkpointing = value
129
+
130
+ def enable_tiling(self, use_tiling: bool = True):
131
+ r"""
132
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
133
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
134
+ processing larger images.
135
+ """
136
+ self.use_tiling = use_tiling
137
+
138
+ def disable_tiling(self):
139
+ r"""
140
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
141
+ decoding in one step.
142
+ """
143
+ self.enable_tiling(False)
144
+
145
+ def enable_slicing(self):
146
+ r"""
147
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
148
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
149
+ """
150
+ self.use_slicing = True
151
+
152
+ def disable_slicing(self):
153
+ r"""
154
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
155
+ decoding in one step.
156
+ """
157
+ self.use_slicing = False
158
+
159
+ @property
160
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
161
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
162
+ r"""
163
+ Returns:
164
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
165
+ indexed by its weight name.
166
+ """
167
+ # set recursively
168
+ processors = {}
169
+
170
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
171
+ if hasattr(module, "get_processor"):
172
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
173
+
174
+ for sub_name, child in module.named_children():
175
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
176
+
177
+ return processors
178
+
179
+ for name, module in self.named_children():
180
+ fn_recursive_add_processors(name, module, processors)
181
+
182
+ return processors
183
+
184
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
185
+ def set_attn_processor(
186
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
187
+ ):
188
+ r"""
189
+ Sets the attention processor to use to compute attention.
190
+
191
+ Parameters:
192
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
193
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
194
+ for **all** `Attention` layers.
195
+
196
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
197
+ processor. This is strongly recommended when setting trainable attention processors.
198
+
199
+ """
200
+ count = len(self.attn_processors.keys())
201
+
202
+ if isinstance(processor, dict) and len(processor) != count:
203
+ raise ValueError(
204
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
205
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
206
+ )
207
+
208
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
209
+ if hasattr(module, "set_processor"):
210
+ if not isinstance(processor, dict):
211
+ module.set_processor(processor, _remove_lora=_remove_lora)
212
+ else:
213
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
214
+
215
+ for sub_name, child in module.named_children():
216
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
217
+
218
+ for name, module in self.named_children():
219
+ fn_recursive_attn_processor(name, module, processor)
220
+
221
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
222
+ def set_default_attn_processor(self):
223
+ """
224
+ Disables custom attention processors and sets the default attention implementation.
225
+ """
226
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
227
+ processor = AttnAddedKVProcessor()
228
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
229
+ processor = AttnProcessor()
230
+ else:
231
+ raise ValueError(
232
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
233
+ )
234
+
235
+ self.set_attn_processor(processor, _remove_lora=True)
236
+
237
+ @apply_forward_hook
238
+ def encode(
239
+ self, x: torch.FloatTensor, return_dict: bool = True
240
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
241
+ """
242
+ Encode a batch of images into latents.
243
+
244
+ Args:
245
+ x (`torch.FloatTensor`): Input batch of images.
246
+ return_dict (`bool`, *optional*, defaults to `True`):
247
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
248
+
249
+ Returns:
250
+ The latent representations of the encoded images. If `return_dict` is True, a
251
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
252
+ """
253
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
254
+ return self.tiled_encode(x, return_dict=return_dict)
255
+
256
+ if self.use_slicing and x.shape[0] > 1:
257
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
258
+ h = torch.cat(encoded_slices)
259
+ else:
260
+ h = self.encoder(x)
261
+
262
+ moments = self.quant_conv(h)
263
+ posterior = DiagonalGaussianDistribution(moments)
264
+
265
+ if not return_dict:
266
+ return (posterior,)
267
+
268
+ return AutoencoderKLOutput(latent_dist=posterior)
269
+
270
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
271
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
272
+ return self.tiled_decode(z, return_dict=return_dict)
273
+
274
+ z = self.post_quant_conv(z)
275
+ dec = self.decoder(z)
276
+
277
+ if not return_dict:
278
+ return (dec,)
279
+
280
+ return DecoderOutput(sample=dec)
281
+
282
+ @apply_forward_hook
283
+ def decode(
284
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
285
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
286
+ """
287
+ Decode a batch of images.
288
+
289
+ Args:
290
+ z (`torch.FloatTensor`): Input batch of latent vectors.
291
+ return_dict (`bool`, *optional*, defaults to `True`):
292
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
293
+
294
+ Returns:
295
+ [`~models.vae.DecoderOutput`] or `tuple`:
296
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
297
+ returned.
298
+
299
+ """
300
+ if self.use_slicing and z.shape[0] > 1:
301
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
302
+ decoded = torch.cat(decoded_slices)
303
+ else:
304
+ decoded = self._decode(z).sample
305
+
306
+ if not return_dict:
307
+ return (decoded,)
308
+
309
+ return DecoderOutput(sample=decoded)
310
+
311
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
312
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
313
+ for y in range(blend_extent):
314
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
315
+ return b
316
+
317
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
318
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
319
+ for x in range(blend_extent):
320
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
321
+ return b
322
+
323
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
324
+ r"""Encode a batch of images using a tiled encoder.
325
+
326
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
327
+ steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
328
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
329
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
330
+ output, but they should be much less noticeable.
331
+
332
+ Args:
333
+ x (`torch.FloatTensor`): Input batch of images.
334
+ return_dict (`bool`, *optional*, defaults to `True`):
335
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
336
+
337
+ Returns:
338
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
339
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
340
+ `tuple` is returned.
341
+ """
342
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
343
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
344
+ row_limit = self.tile_latent_min_size - blend_extent
345
+
346
+ # Split the image into 512x512 tiles and encode them separately.
347
+ rows = []
348
+ for i in range(0, x.shape[2], overlap_size):
349
+ row = []
350
+ for j in range(0, x.shape[3], overlap_size):
351
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
352
+ tile = self.encoder(tile)
353
+ tile = self.quant_conv(tile)
354
+ row.append(tile)
355
+ rows.append(row)
356
+ result_rows = []
357
+ for i, row in enumerate(rows):
358
+ result_row = []
359
+ for j, tile in enumerate(row):
360
+ # blend the above tile and the left tile
361
+ # to the current tile and add the current tile to the result row
362
+ if i > 0:
363
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
364
+ if j > 0:
365
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
366
+ result_row.append(tile[:, :, :row_limit, :row_limit])
367
+ result_rows.append(torch.cat(result_row, dim=3))
368
+
369
+ moments = torch.cat(result_rows, dim=2)
370
+ posterior = DiagonalGaussianDistribution(moments)
371
+
372
+ if not return_dict:
373
+ return (posterior,)
374
+
375
+ return AutoencoderKLOutput(latent_dist=posterior)
376
+
377
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
378
+ r"""
379
+ Decode a batch of images using a tiled decoder.
380
+
381
+ Args:
382
+ z (`torch.FloatTensor`): Input batch of latent vectors.
383
+ return_dict (`bool`, *optional*, defaults to `True`):
384
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
385
+
386
+ Returns:
387
+ [`~models.vae.DecoderOutput`] or `tuple`:
388
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
389
+ returned.
390
+ """
391
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
392
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
393
+ row_limit = self.tile_sample_min_size - blend_extent
394
+
395
+ # Split z into overlapping 64x64 tiles and decode them separately.
396
+ # The tiles have an overlap to avoid seams between tiles.
397
+ rows = []
398
+ for i in range(0, z.shape[2], overlap_size):
399
+ row = []
400
+ for j in range(0, z.shape[3], overlap_size):
401
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
402
+ tile = self.post_quant_conv(tile)
403
+ decoded = self.decoder(tile)
404
+ row.append(decoded)
405
+ rows.append(row)
406
+ result_rows = []
407
+ for i, row in enumerate(rows):
408
+ result_row = []
409
+ for j, tile in enumerate(row):
410
+ # blend the above tile and the left tile
411
+ # to the current tile and add the current tile to the result row
412
+ if i > 0:
413
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
414
+ if j > 0:
415
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
416
+ result_row.append(tile[:, :, :row_limit, :row_limit])
417
+ result_rows.append(torch.cat(result_row, dim=3))
418
+
419
+ dec = torch.cat(result_rows, dim=2)
420
+ if not return_dict:
421
+ return (dec,)
422
+
423
+ return DecoderOutput(sample=dec)
424
+
425
+ def forward(
426
+ self,
427
+ sample: torch.FloatTensor,
428
+ sample_posterior: bool = False,
429
+ return_dict: bool = True,
430
+ generator: Optional[torch.Generator] = None,
431
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
432
+ r"""
433
+ Args:
434
+ sample (`torch.FloatTensor`): Input sample.
435
+ sample_posterior (`bool`, *optional*, defaults to `False`):
436
+ Whether to sample from the posterior.
437
+ return_dict (`bool`, *optional*, defaults to `True`):
438
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
439
+ """
440
+ x = sample
441
+ posterior = self.encode(x).latent_dist
442
+ if sample_posterior:
443
+ z = posterior.sample(generator=generator)
444
+ else:
445
+ z = posterior.mode()
446
+ dec = self.decode(z).sample
447
+
448
+ if not return_dict:
449
+ return (dec,)
450
+
451
+ return DecoderOutput(sample=dec)
452
+
453
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
454
+ def fuse_qkv_projections(self):
455
+ """
456
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
457
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
458
+
459
+ <Tip warning={true}>
460
+
461
+ This API is 🧪 experimental.
462
+
463
+ </Tip>
464
+ """
465
+ self.original_attn_processors = None
466
+
467
+ for _, attn_processor in self.attn_processors.items():
468
+ if "Added" in str(attn_processor.__class__.__name__):
469
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
470
+
471
+ self.original_attn_processors = self.attn_processors
472
+
473
+ for module in self.modules():
474
+ if isinstance(module, Attention):
475
+ module.fuse_projections(fuse=True)
476
+
477
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
478
+ def unfuse_qkv_projections(self):
479
+ """Disables the fused QKV projection if enabled.
480
+
481
+ <Tip warning={true}>
482
+
483
+ This API is 🧪 experimental.
484
+
485
+ </Tip>
486
+
487
+ """
488
+ if self.original_attn_processors is not None:
489
+ self.set_attn_processor(self.original_attn_processors)
module/diffusers_vae/vae.py ADDED
@@ -0,0 +1,985 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from diffusers.utils import BaseOutput, is_torch_version
22
+ from diffusers.utils.torch_utils import randn_tensor
23
+ from diffusers.models.activations import get_activation
24
+ from diffusers.models.attention_processor import SpatialNorm
25
+ from diffusers.models.unet_2d_blocks import (
26
+ AutoencoderTinyBlock,
27
+ UNetMidBlock2D,
28
+ get_down_block,
29
+ get_up_block,
30
+ )
31
+
32
+
33
+ @dataclass
34
+ class DecoderOutput(BaseOutput):
35
+ r"""
36
+ Output of decoding method.
37
+
38
+ Args:
39
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
40
+ The decoded output sample from the last layer of the model.
41
+ """
42
+
43
+ sample: torch.FloatTensor
44
+
45
+
46
+ class Encoder(nn.Module):
47
+ r"""
48
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
49
+
50
+ Args:
51
+ in_channels (`int`, *optional*, defaults to 3):
52
+ The number of input channels.
53
+ out_channels (`int`, *optional*, defaults to 3):
54
+ The number of output channels.
55
+ down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
56
+ The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
57
+ options.
58
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
59
+ The number of output channels for each block.
60
+ layers_per_block (`int`, *optional*, defaults to 2):
61
+ The number of layers per block.
62
+ norm_num_groups (`int`, *optional*, defaults to 32):
63
+ The number of groups for normalization.
64
+ act_fn (`str`, *optional*, defaults to `"silu"`):
65
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
66
+ double_z (`bool`, *optional*, defaults to `True`):
67
+ Whether to double the number of output channels for the last block.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ in_channels: int = 3,
73
+ out_channels: int = 3,
74
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
75
+ block_out_channels: Tuple[int, ...] = (64,),
76
+ layers_per_block: int = 2,
77
+ norm_num_groups: int = 32,
78
+ act_fn: str = "silu",
79
+ double_z: bool = True,
80
+ mid_block_add_attention=True,
81
+ ):
82
+ super().__init__()
83
+ self.layers_per_block = layers_per_block
84
+
85
+ self.conv_in = nn.Conv2d(
86
+ in_channels,
87
+ block_out_channels[0],
88
+ kernel_size=3,
89
+ stride=1,
90
+ padding=1,
91
+ )
92
+
93
+ self.mid_block = None
94
+ self.down_blocks = nn.ModuleList([])
95
+
96
+ # down
97
+ output_channel = block_out_channels[0]
98
+ for i, down_block_type in enumerate(down_block_types):
99
+ input_channel = output_channel
100
+ output_channel = block_out_channels[i]
101
+ is_final_block = i == len(block_out_channels) - 1
102
+
103
+ down_block = get_down_block(
104
+ down_block_type,
105
+ num_layers=self.layers_per_block,
106
+ in_channels=input_channel,
107
+ out_channels=output_channel,
108
+ add_downsample=not is_final_block,
109
+ resnet_eps=1e-6,
110
+ downsample_padding=0,
111
+ resnet_act_fn=act_fn,
112
+ resnet_groups=norm_num_groups,
113
+ attention_head_dim=output_channel,
114
+ temb_channels=None,
115
+ )
116
+ self.down_blocks.append(down_block)
117
+
118
+ # mid
119
+ self.mid_block = UNetMidBlock2D(
120
+ in_channels=block_out_channels[-1],
121
+ resnet_eps=1e-6,
122
+ resnet_act_fn=act_fn,
123
+ output_scale_factor=1,
124
+ resnet_time_scale_shift="default",
125
+ attention_head_dim=block_out_channels[-1],
126
+ resnet_groups=norm_num_groups,
127
+ temb_channels=None,
128
+ add_attention=mid_block_add_attention,
129
+ )
130
+
131
+ # out
132
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
133
+ self.conv_act = nn.SiLU()
134
+
135
+ conv_out_channels = 2 * out_channels if double_z else out_channels
136
+ self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
137
+
138
+ self.gradient_checkpointing = False
139
+
140
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
141
+ r"""The forward method of the `Encoder` class."""
142
+
143
+ sample = self.conv_in(sample)
144
+
145
+ if self.training and self.gradient_checkpointing:
146
+
147
+ def create_custom_forward(module):
148
+ def custom_forward(*inputs):
149
+ return module(*inputs)
150
+
151
+ return custom_forward
152
+
153
+ # down
154
+ if is_torch_version(">=", "1.11.0"):
155
+ for down_block in self.down_blocks:
156
+ sample = torch.utils.checkpoint.checkpoint(
157
+ create_custom_forward(down_block), sample, use_reentrant=False
158
+ )
159
+ # middle
160
+ sample = torch.utils.checkpoint.checkpoint(
161
+ create_custom_forward(self.mid_block), sample, use_reentrant=False
162
+ )
163
+ else:
164
+ for down_block in self.down_blocks:
165
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
166
+ # middle
167
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
168
+
169
+ else:
170
+ # down
171
+ for down_block in self.down_blocks:
172
+ sample = down_block(sample)
173
+
174
+ # middle
175
+ sample = self.mid_block(sample)
176
+
177
+ # post-process
178
+ sample = self.conv_norm_out(sample)
179
+ sample = self.conv_act(sample)
180
+ sample = self.conv_out(sample)
181
+
182
+ return sample
183
+
184
+
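+ # A minimal usage sketch, assuming a typical SD-style VAE config (illustrative
+ # values, not taken from this repo): the encoder maps an image to latent
+ # moments, with the output channels doubled when `double_z=True`.
+ #   enc = Encoder(in_channels=3, out_channels=4,
+ #                 down_block_types=("DownEncoderBlock2D",) * 4,
+ #                 block_out_channels=(128, 256, 512, 512), layers_per_block=2)
+ #   moments = enc(torch.randn(1, 3, 256, 256))  # -> (1, 8, 32, 32)
+ 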
185
+ class Decoder(nn.Module):
186
+ r"""
187
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
188
+
189
+ Args:
190
+ in_channels (`int`, *optional*, defaults to 3):
191
+ The number of input channels.
192
+ out_channels (`int`, *optional*, defaults to 3):
193
+ The number of output channels.
194
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
195
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
196
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
197
+ The number of output channels for each block.
198
+ layers_per_block (`int`, *optional*, defaults to 2):
199
+ The number of layers per block.
200
+ norm_num_groups (`int`, *optional*, defaults to 32):
201
+ The number of groups for normalization.
202
+ act_fn (`str`, *optional*, defaults to `"silu"`):
203
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
204
+ norm_type (`str`, *optional*, defaults to `"group"`):
205
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
206
+ """
207
+
208
+ def __init__(
209
+ self,
210
+ in_channels: int = 3,
211
+ out_channels: int = 3,
212
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
213
+ block_out_channels: Tuple[int, ...] = (64,),
214
+ layers_per_block: int = 2,
215
+ norm_num_groups: int = 32,
216
+ act_fn: str = "silu",
217
+ norm_type: str = "group", # group, spatial
218
+ mid_block_add_attention=True,
219
+ ):
220
+ super().__init__()
221
+ self.layers_per_block = layers_per_block
222
+
223
+ self.conv_in = nn.Conv2d(
224
+ in_channels,
225
+ block_out_channels[-1],
226
+ kernel_size=3,
227
+ stride=1,
228
+ padding=1,
229
+ )
230
+
231
+ self.mid_block = None
232
+ self.up_blocks = nn.ModuleList([])
233
+
234
+ temb_channels = in_channels if norm_type == "spatial" else None
235
+
236
+ # mid
237
+ self.mid_block = UNetMidBlock2D(
238
+ in_channels=block_out_channels[-1],
239
+ resnet_eps=1e-6,
240
+ resnet_act_fn=act_fn,
241
+ output_scale_factor=1,
242
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
243
+ attention_head_dim=block_out_channels[-1],
244
+ resnet_groups=norm_num_groups,
245
+ temb_channels=temb_channels,
246
+ add_attention=mid_block_add_attention,
247
+ )
248
+
249
+ # up
250
+ reversed_block_out_channels = list(reversed(block_out_channels))
251
+ output_channel = reversed_block_out_channels[0]
252
+ for i, up_block_type in enumerate(up_block_types):
253
+ prev_output_channel = output_channel
254
+ output_channel = reversed_block_out_channels[i]
255
+
256
+ is_final_block = i == len(block_out_channels) - 1
257
+
258
+ up_block = get_up_block(
259
+ up_block_type,
260
+ num_layers=self.layers_per_block + 1,
261
+ in_channels=prev_output_channel,
262
+ out_channels=output_channel,
263
+ prev_output_channel=None,
264
+ add_upsample=not is_final_block,
265
+ resnet_eps=1e-6,
266
+ resnet_act_fn=act_fn,
267
+ resnet_groups=norm_num_groups,
268
+ attention_head_dim=output_channel,
269
+ temb_channels=temb_channels,
270
+ resnet_time_scale_shift=norm_type,
271
+ )
272
+ self.up_blocks.append(up_block)
273
+ prev_output_channel = output_channel
274
+
275
+ # out
276
+ if norm_type == "spatial":
277
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
278
+ else:
279
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
280
+ self.conv_act = nn.SiLU()
281
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
282
+
283
+ self.gradient_checkpointing = False
284
+
285
+ def forward(
286
+ self,
287
+ sample: torch.FloatTensor,
288
+ latent_embeds: Optional[torch.FloatTensor] = None,
289
+ ) -> torch.FloatTensor:
290
+ r"""The forward method of the `Decoder` class."""
291
+
292
+ sample = self.conv_in(sample)
293
+ sample = sample.to(torch.float32)
294
+
295
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
296
+
297
+ if self.training and self.gradient_checkpointing:
298
+
299
+ def create_custom_forward(module):
300
+ def custom_forward(*inputs):
301
+ return module(*inputs)
302
+
303
+ return custom_forward
304
+
305
+ if is_torch_version(">=", "1.11.0"):
306
+ # middle
307
+ sample = torch.utils.checkpoint.checkpoint(
308
+ create_custom_forward(self.mid_block),
309
+ sample,
310
+ latent_embeds,
311
+ use_reentrant=False,
312
+ )
313
+ sample = sample.to(upscale_dtype)
314
+
315
+ # up
316
+ for up_block in self.up_blocks:
317
+ sample = torch.utils.checkpoint.checkpoint(
318
+ create_custom_forward(up_block),
319
+ sample,
320
+ latent_embeds,
321
+ use_reentrant=False,
322
+ )
323
+ else:
324
+ # middle
325
+ sample = torch.utils.checkpoint.checkpoint(
326
+ create_custom_forward(self.mid_block), sample, latent_embeds
327
+ )
328
+ sample = sample.to(upscale_dtype)
329
+
330
+ # up
331
+ for up_block in self.up_blocks:
332
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
333
+ else:
334
+ # middle
335
+ sample = self.mid_block(sample, latent_embeds)
336
+ sample = sample.to(upscale_dtype)
337
+
338
+ # up
339
+ for up_block in self.up_blocks:
340
+ sample = up_block(sample, latent_embeds)
341
+
342
+ # post-process
343
+ if latent_embeds is None:
344
+ sample = self.conv_norm_out(sample)
345
+ else:
346
+ sample = self.conv_norm_out(sample, latent_embeds)
347
+ sample = self.conv_act(sample)
348
+ sample = self.conv_out(sample)
349
+
350
+ return sample
351
+
352
+
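+ # A minimal sketch mirroring the encoder example above (illustrative values):
+ #   dec = Decoder(in_channels=4, out_channels=3,
+ #                 up_block_types=("UpDecoderBlock2D",) * 4,
+ #                 block_out_channels=(128, 256, 512, 512), layers_per_block=2)
+ #   image = dec(torch.randn(1, 4, 32, 32))  # -> (1, 3, 256, 256)
+ 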
353
+ class UpSample(nn.Module):
354
+ r"""
355
+ The `UpSample` layer of a variational autoencoder that upsamples its input.
356
+
357
+ Args:
358
+ in_channels (`int`, *optional*, defaults to 3):
359
+ The number of input channels.
360
+ out_channels (`int`, *optional*, defaults to 3):
361
+ The number of output channels.
362
+ """
363
+
364
+ def __init__(
365
+ self,
366
+ in_channels: int,
367
+ out_channels: int,
368
+ ) -> None:
369
+ super().__init__()
370
+ self.in_channels = in_channels
371
+ self.out_channels = out_channels
372
+ self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
373
+
374
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
375
+ r"""The forward method of the `UpSample` class."""
376
+ x = torch.relu(x)
377
+ x = self.deconv(x)
378
+ return x
379
+
380
+
381
+ class MaskConditionEncoder(nn.Module):
382
+ """
383
+ used in AsymmetricAutoencoderKL
384
+ """
385
+
386
+ def __init__(
387
+ self,
388
+ in_ch: int,
389
+ out_ch: int = 192,
390
+ res_ch: int = 768,
391
+ stride: int = 16,
392
+ ) -> None:
393
+ super().__init__()
394
+
395
+ channels = []
396
+ while stride > 1:
397
+ stride = stride // 2
398
+ in_ch_ = out_ch * 2
399
+ if out_ch > res_ch:
400
+ out_ch = res_ch
401
+ if stride == 1:
402
+ in_ch_ = res_ch
403
+ channels.append((in_ch_, out_ch))
404
+ out_ch *= 2
405
+
406
+ out_channels = []
407
+ for _in_ch, _out_ch in channels:
408
+ out_channels.append(_out_ch)
409
+ out_channels.append(channels[-1][0])
410
+
411
+ layers = []
412
+ in_ch_ = in_ch
413
+ for l in range(len(out_channels)):
414
+ out_ch_ = out_channels[l]
415
+ if l == 0 or l == 1:
416
+ layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=3, stride=1, padding=1))
417
+ else:
418
+ layers.append(nn.Conv2d(in_ch_, out_ch_, kernel_size=4, stride=2, padding=1))
419
+ in_ch_ = out_ch_
420
+
421
+ self.layers = nn.Sequential(*layers)
422
+
423
+ def forward(self, x: torch.FloatTensor, mask=None) -> torch.FloatTensor:
424
+ r"""The forward method of the `MaskConditionEncoder` class."""
425
+ out = {}
426
+ for l in range(len(self.layers)):
427
+ layer = self.layers[l]
428
+ x = layer(x)
429
+ out[str(tuple(x.shape))] = x
430
+ x = torch.relu(x)
431
+ return out
432
+
433
+
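+ # Note (reading of the code above): the encoder returns a dict of intermediate
+ # feature maps keyed by `str(tuple(shape))`; `MaskConditionDecoder` below looks
+ # up the entry matching the current sample's shape to blend in masked-image
+ # features at every decoding scale.
+ 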
434
+ class MaskConditionDecoder(nn.Module):
435
+ r"""The `MaskConditionDecoder` should be used in combination with [`AsymmetricAutoencoderKL`] to enhance the model's
436
+ decoder with a conditioner on the mask and masked image.
437
+
438
+ Args:
439
+ in_channels (`int`, *optional*, defaults to 3):
440
+ The number of input channels.
441
+ out_channels (`int`, *optional*, defaults to 3):
442
+ The number of output channels.
443
+ up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
444
+ The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
445
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
446
+ The number of output channels for each block.
447
+ layers_per_block (`int`, *optional*, defaults to 2):
448
+ The number of layers per block.
449
+ norm_num_groups (`int`, *optional*, defaults to 32):
450
+ The number of groups for normalization.
451
+ act_fn (`str`, *optional*, defaults to `"silu"`):
452
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
453
+ norm_type (`str`, *optional*, defaults to `"group"`):
454
+ The normalization type to use. Can be either `"group"` or `"spatial"`.
455
+ """
456
+
457
+ def __init__(
458
+ self,
459
+ in_channels: int = 3,
460
+ out_channels: int = 3,
461
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
462
+ block_out_channels: Tuple[int, ...] = (64,),
463
+ layers_per_block: int = 2,
464
+ norm_num_groups: int = 32,
465
+ act_fn: str = "silu",
466
+ norm_type: str = "group", # group, spatial
467
+ ):
468
+ super().__init__()
469
+ self.layers_per_block = layers_per_block
470
+
471
+ self.conv_in = nn.Conv2d(
472
+ in_channels,
473
+ block_out_channels[-1],
474
+ kernel_size=3,
475
+ stride=1,
476
+ padding=1,
477
+ )
478
+
479
+ self.mid_block = None
480
+ self.up_blocks = nn.ModuleList([])
481
+
482
+ temb_channels = in_channels if norm_type == "spatial" else None
483
+
484
+ # mid
485
+ self.mid_block = UNetMidBlock2D(
486
+ in_channels=block_out_channels[-1],
487
+ resnet_eps=1e-6,
488
+ resnet_act_fn=act_fn,
489
+ output_scale_factor=1,
490
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
491
+ attention_head_dim=block_out_channels[-1],
492
+ resnet_groups=norm_num_groups,
493
+ temb_channels=temb_channels,
494
+ )
495
+
496
+ # up
497
+ reversed_block_out_channels = list(reversed(block_out_channels))
498
+ output_channel = reversed_block_out_channels[0]
499
+ for i, up_block_type in enumerate(up_block_types):
500
+ prev_output_channel = output_channel
501
+ output_channel = reversed_block_out_channels[i]
502
+
503
+ is_final_block = i == len(block_out_channels) - 1
504
+
505
+ up_block = get_up_block(
506
+ up_block_type,
507
+ num_layers=self.layers_per_block + 1,
508
+ in_channels=prev_output_channel,
509
+ out_channels=output_channel,
510
+ prev_output_channel=None,
511
+ add_upsample=not is_final_block,
512
+ resnet_eps=1e-6,
513
+ resnet_act_fn=act_fn,
514
+ resnet_groups=norm_num_groups,
515
+ attention_head_dim=output_channel,
516
+ temb_channels=temb_channels,
517
+ resnet_time_scale_shift=norm_type,
518
+ )
519
+ self.up_blocks.append(up_block)
520
+ prev_output_channel = output_channel
521
+
522
+ # condition encoder
523
+ self.condition_encoder = MaskConditionEncoder(
524
+ in_ch=out_channels,
525
+ out_ch=block_out_channels[0],
526
+ res_ch=block_out_channels[-1],
527
+ )
528
+
529
+ # out
530
+ if norm_type == "spatial":
531
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
532
+ else:
533
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
534
+ self.conv_act = nn.SiLU()
535
+ self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
536
+
537
+ self.gradient_checkpointing = False
538
+
539
+ def forward(
540
+ self,
541
+ z: torch.FloatTensor,
542
+ image: Optional[torch.FloatTensor] = None,
543
+ mask: Optional[torch.FloatTensor] = None,
544
+ latent_embeds: Optional[torch.FloatTensor] = None,
545
+ ) -> torch.FloatTensor:
546
+ r"""The forward method of the `MaskConditionDecoder` class."""
547
+ sample = z
548
+ sample = self.conv_in(sample)
549
+
550
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
551
+ if self.training and self.gradient_checkpointing:
552
+
553
+ def create_custom_forward(module):
554
+ def custom_forward(*inputs):
555
+ return module(*inputs)
556
+
557
+ return custom_forward
558
+
559
+ if is_torch_version(">=", "1.11.0"):
560
+ # middle
561
+ sample = torch.utils.checkpoint.checkpoint(
562
+ create_custom_forward(self.mid_block),
563
+ sample,
564
+ latent_embeds,
565
+ use_reentrant=False,
566
+ )
567
+ sample = sample.to(upscale_dtype)
568
+
569
+ # condition encoder
570
+ if image is not None and mask is not None:
571
+ masked_image = (1 - mask) * image
572
+ im_x = torch.utils.checkpoint.checkpoint(
573
+ create_custom_forward(self.condition_encoder),
574
+ masked_image,
575
+ mask,
576
+ use_reentrant=False,
577
+ )
578
+
579
+ # up
580
+ for up_block in self.up_blocks:
581
+ if image is not None and mask is not None:
582
+ sample_ = im_x[str(tuple(sample.shape))]
583
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
584
+ sample = sample * mask_ + sample_ * (1 - mask_)
585
+ sample = torch.utils.checkpoint.checkpoint(
586
+ create_custom_forward(up_block),
587
+ sample,
588
+ latent_embeds,
589
+ use_reentrant=False,
590
+ )
591
+ if image is not None and mask is not None:
592
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
593
+ else:
594
+ # middle
595
+ sample = torch.utils.checkpoint.checkpoint(
596
+ create_custom_forward(self.mid_block), sample, latent_embeds
597
+ )
598
+ sample = sample.to(upscale_dtype)
599
+
600
+ # condition encoder
601
+ if image is not None and mask is not None:
602
+ masked_image = (1 - mask) * image
603
+ im_x = torch.utils.checkpoint.checkpoint(
604
+ create_custom_forward(self.condition_encoder),
605
+ masked_image,
606
+ mask,
607
+ )
608
+
609
+ # up
610
+ for up_block in self.up_blocks:
611
+ if image is not None and mask is not None:
612
+ sample_ = im_x[str(tuple(sample.shape))]
613
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
614
+ sample = sample * mask_ + sample_ * (1 - mask_)
615
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
616
+ if image is not None and mask is not None:
617
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
618
+ else:
619
+ # middle
620
+ sample = self.mid_block(sample, latent_embeds)
621
+ sample = sample.to(upscale_dtype)
622
+
623
+ # condition encoder
624
+ if image is not None and mask is not None:
625
+ masked_image = (1 - mask) * image
626
+ im_x = self.condition_encoder(masked_image, mask)
627
+
628
+ # up
629
+ for up_block in self.up_blocks:
630
+ if image is not None and mask is not None:
631
+ sample_ = im_x[str(tuple(sample.shape))]
632
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
633
+ sample = sample * mask_ + sample_ * (1 - mask_)
634
+ sample = up_block(sample, latent_embeds)
635
+ if image is not None and mask is not None:
636
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
637
+
638
+ # post-process
639
+ if latent_embeds is None:
640
+ sample = self.conv_norm_out(sample)
641
+ else:
642
+ sample = self.conv_norm_out(sample, latent_embeds)
643
+ sample = self.conv_act(sample)
644
+ sample = self.conv_out(sample)
645
+
646
+ return sample
647
+
648
+
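+ # Note (reading of the code above): with `mask == 1` marking masked-out pixels,
+ # each scale blends the running sample with the matching condition feature,
+ #   sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
+ # so content from the known regions is progressively re-injected while decoding.
+ 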
649
+ class VectorQuantizer(nn.Module):
650
+ """
651
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix
652
+ multiplications and allows for post-hoc remapping of indices.
653
+ """
654
+
655
+ # NOTE: due to a bug, the beta term was applied to the wrong term. For
656
+ # backwards compatibility we use the buggy version by default, but you can
657
+ # specify legacy=False to fix it.
658
+ def __init__(
659
+ self,
660
+ n_e: int,
661
+ vq_embed_dim: int,
662
+ beta: float,
663
+ remap=None,
664
+ unknown_index: str = "random",
665
+ sane_index_shape: bool = False,
666
+ legacy: bool = True,
667
+ ):
668
+ super().__init__()
669
+ self.n_e = n_e
670
+ self.vq_embed_dim = vq_embed_dim
671
+ self.beta = beta
672
+ self.legacy = legacy
673
+
674
+ self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
675
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
676
+
677
+ self.remap = remap
678
+ if self.remap is not None:
679
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
680
+ self.used: torch.Tensor
681
+ self.re_embed = self.used.shape[0]
682
+ self.unknown_index = unknown_index # "random" or "extra" or integer
683
+ if self.unknown_index == "extra":
684
+ self.unknown_index = self.re_embed
685
+ self.re_embed = self.re_embed + 1
686
+ print(
687
+ f"Remapping {self.n_e} indices to {self.re_embed} indices. "
688
+ f"Using {self.unknown_index} for unknown indices."
689
+ )
690
+ else:
691
+ self.re_embed = n_e
692
+
693
+ self.sane_index_shape = sane_index_shape
694
+
695
+ def remap_to_used(self, inds: torch.LongTensor) -> torch.LongTensor:
696
+ ishape = inds.shape
697
+ assert len(ishape) > 1
698
+ inds = inds.reshape(ishape[0], -1)
699
+ used = self.used.to(inds)
700
+ match = (inds[:, :, None] == used[None, None, ...]).long()
701
+ new = match.argmax(-1)
702
+ unknown = match.sum(2) < 1
703
+ if self.unknown_index == "random":
704
+ new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
705
+ else:
706
+ new[unknown] = self.unknown_index
707
+ return new.reshape(ishape)
708
+
709
+ def unmap_to_all(self, inds: torch.LongTensor) -> torch.LongTensor:
710
+ ishape = inds.shape
711
+ assert len(ishape) > 1
712
+ inds = inds.reshape(ishape[0], -1)
713
+ used = self.used.to(inds)
714
+ if self.re_embed > self.used.shape[0]: # extra token
715
+ inds[inds >= self.used.shape[0]] = 0 # simply set to zero
716
+ back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
717
+ return back.reshape(ishape)
718
+
719
+ def forward(self, z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, Tuple]:
720
+ # reshape z -> (batch, height, width, channel) and flatten
721
+ z = z.permute(0, 2, 3, 1).contiguous()
722
+ z_flattened = z.view(-1, self.vq_embed_dim)
723
+
724
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
725
+ min_encoding_indices = torch.argmin(torch.cdist(z_flattened, self.embedding.weight), dim=1)
726
+
727
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
728
+ perplexity = None
729
+ min_encodings = None
730
+
731
+ # compute loss for embedding
732
+ if not self.legacy:
733
+ loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
734
+ else:
735
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
736
+
737
+ # preserve gradients
738
+ z_q: torch.FloatTensor = z + (z_q - z).detach()
739
+
740
+ # reshape back to match original input shape
741
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
742
+
743
+ if self.remap is not None:
744
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
745
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
746
+ min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
747
+
748
+ if self.sane_index_shape:
749
+ min_encoding_indices = min_encoding_indices.reshape(z_q.shape[0], z_q.shape[2], z_q.shape[3])
750
+
751
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
752
+
753
+ def get_codebook_entry(self, indices: torch.LongTensor, shape: Tuple[int, ...]) -> torch.FloatTensor:
754
+ # shape specifying (batch, height, width, channel)
755
+ if self.remap is not None:
756
+ indices = indices.reshape(shape[0], -1) # add batch axis
757
+ indices = self.unmap_to_all(indices)
758
+ indices = indices.reshape(-1) # flatten again
759
+
760
+ # get quantized latent vectors
761
+ z_q: torch.FloatTensor = self.embedding(indices)
762
+
763
+ if shape is not None:
764
+ z_q = z_q.view(shape)
765
+ # reshape back to match original input shape
766
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
767
+
768
+ return z_q
769
+
770
+
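+ # A minimal sketch, assuming illustrative hyper-parameters: each spatial latent
+ # vector is snapped to its nearest codebook entry, and gradients flow through
+ # the straight-through estimator `z + (z_q - z).detach()` used above.
+ #   vq = VectorQuantizer(n_e=8192, vq_embed_dim=4, beta=0.25)
+ #   z_q, vq_loss, (_, _, indices) = vq(torch.randn(1, 4, 32, 32))  # z_q: (1, 4, 32, 32)
+ 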
771
+ class DiagonalGaussianDistribution(object):
772
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
773
+ self.parameters = parameters
774
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
775
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
776
+ self.deterministic = deterministic
777
+ self.std = torch.exp(0.5 * self.logvar)
778
+ self.var = torch.exp(self.logvar)
779
+ if self.deterministic:
780
+ self.var = self.std = torch.zeros_like(
781
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
782
+ )
783
+
784
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
785
+ # make sure sample is on the same device as the parameters and has same dtype
786
+ sample = randn_tensor(
787
+ self.mean.shape,
788
+ generator=generator,
789
+ device=self.parameters.device,
790
+ dtype=self.parameters.dtype,
791
+ )
792
+ x = self.mean + self.std * sample
793
+ return x
794
+
795
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
796
+ if self.deterministic:
797
+ return torch.Tensor([0.0])
798
+ else:
799
+ if other is None:
800
+ return 0.5 * torch.sum(
801
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
802
+ dim=[1, 2, 3],
803
+ )
804
+ else:
805
+ return 0.5 * torch.sum(
806
+ torch.pow(self.mean - other.mean, 2) / other.var
807
+ + self.var / other.var
808
+ - 1.0
809
+ - self.logvar
810
+ + other.logvar,
811
+ dim=[1, 2, 3],
812
+ )
813
+
814
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
815
+ if self.deterministic:
816
+ return torch.Tensor([0.0])
817
+ logtwopi = np.log(2.0 * np.pi)
818
+ return 0.5 * torch.sum(
819
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
820
+ dim=dims,
821
+ )
822
+
823
+ def mode(self) -> torch.Tensor:
824
+ return self.mean
825
+
826
+
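+ # A minimal sketch: `parameters` packs mean and logvar along dim=1, so an
+ # 8-channel moments tensor yields a 4-channel latent; `sample()` draws
+ # mean + std * eps and `kl()` is the closed-form KL to a standard normal.
+ #   posterior = DiagonalGaussianDistribution(moments)  # moments: (B, 8, h, w)
+ #   z = posterior.sample()                             # (B, 4, h, w)
+ #   kl = posterior.kl()                                # (B,)
+ 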
827
+ class EncoderTiny(nn.Module):
828
+ r"""
829
+ The `EncoderTiny` layer is a simpler version of the `Encoder` layer.
830
+
831
+ Args:
832
+ in_channels (`int`):
833
+ The number of input channels.
834
+ out_channels (`int`):
835
+ The number of output channels.
836
+ num_blocks (`Tuple[int, ...]`):
837
+ Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
838
+ use.
839
+ block_out_channels (`Tuple[int, ...]`):
840
+ The number of output channels for each block.
841
+ act_fn (`str`):
842
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
843
+ """
844
+
845
+ def __init__(
846
+ self,
847
+ in_channels: int,
848
+ out_channels: int,
849
+ num_blocks: Tuple[int, ...],
850
+ block_out_channels: Tuple[int, ...],
851
+ act_fn: str,
852
+ ):
853
+ super().__init__()
854
+
855
+ layers = []
856
+ for i, num_block in enumerate(num_blocks):
857
+ num_channels = block_out_channels[i]
858
+
859
+ if i == 0:
860
+ layers.append(nn.Conv2d(in_channels, num_channels, kernel_size=3, padding=1))
861
+ else:
862
+ layers.append(
863
+ nn.Conv2d(
864
+ num_channels,
865
+ num_channels,
866
+ kernel_size=3,
867
+ padding=1,
868
+ stride=2,
869
+ bias=False,
870
+ )
871
+ )
872
+
873
+ for _ in range(num_block):
874
+ layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
875
+
876
+ layers.append(nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, padding=1))
877
+
878
+ self.layers = nn.Sequential(*layers)
879
+ self.gradient_checkpointing = False
880
+
881
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
882
+ r"""The forward method of the `EncoderTiny` class."""
883
+ if self.training and self.gradient_checkpointing:
884
+
885
+ def create_custom_forward(module):
886
+ def custom_forward(*inputs):
887
+ return module(*inputs)
888
+
889
+ return custom_forward
890
+
891
+ if is_torch_version(">=", "1.11.0"):
892
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
893
+ else:
894
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
895
+
896
+ else:
897
+ # scale image from [-1, 1] to [0, 1] to match TAESD convention
898
+ x = self.layers(x.add(1).div(2))
899
+
900
+ return x
901
+
902
+
903
+ class DecoderTiny(nn.Module):
904
+ r"""
905
+ The `DecoderTiny` layer is a simpler version of the `Decoder` layer.
906
+
907
+ Args:
908
+ in_channels (`int`):
909
+ The number of input channels.
910
+ out_channels (`int`):
911
+ The number of output channels.
912
+ num_blocks (`Tuple[int, ...]`):
913
+ Each value of the tuple represents a Conv2d layer followed by `value` number of `AutoencoderTinyBlock`'s to
914
+ use.
915
+ block_out_channels (`Tuple[int, ...]`):
916
+ The number of output channels for each block.
917
+ upsampling_scaling_factor (`int`):
918
+ The scaling factor to use for upsampling.
919
+ act_fn (`str`):
920
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
921
+ """
922
+
923
+ def __init__(
924
+ self,
925
+ in_channels: int,
926
+ out_channels: int,
927
+ num_blocks: Tuple[int, ...],
928
+ block_out_channels: Tuple[int, ...],
929
+ upsampling_scaling_factor: int,
930
+ act_fn: str,
931
+ ):
932
+ super().__init__()
933
+
934
+ layers = [
935
+ nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1),
936
+ get_activation(act_fn),
937
+ ]
938
+
939
+ for i, num_block in enumerate(num_blocks):
940
+ is_final_block = i == (len(num_blocks) - 1)
941
+ num_channels = block_out_channels[i]
942
+
943
+ for _ in range(num_block):
944
+ layers.append(AutoencoderTinyBlock(num_channels, num_channels, act_fn))
945
+
946
+ if not is_final_block:
947
+ layers.append(nn.Upsample(scale_factor=upsampling_scaling_factor))
948
+
949
+ conv_out_channel = num_channels if not is_final_block else out_channels
950
+ layers.append(
951
+ nn.Conv2d(
952
+ num_channels,
953
+ conv_out_channel,
954
+ kernel_size=3,
955
+ padding=1,
956
+ bias=is_final_block,
957
+ )
958
+ )
959
+
960
+ self.layers = nn.Sequential(*layers)
961
+ self.gradient_checkpointing = False
962
+
963
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
964
+ r"""The forward method of the `DecoderTiny` class."""
965
+ # Clamp.
966
+ x = torch.tanh(x / 3) * 3
967
+
968
+ if self.training and self.gradient_checkpointing:
969
+
970
+ def create_custom_forward(module):
971
+ def custom_forward(*inputs):
972
+ return module(*inputs)
973
+
974
+ return custom_forward
975
+
976
+ if is_torch_version(">=", "1.11.0"):
977
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
978
+ else:
979
+ x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
980
+
981
+ else:
982
+ x = self.layers(x)
983
+
984
+ # scale image from [0, 1] to [-1, 1] to match diffusers convention
985
+ return x.mul(2).sub(1)
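+ 
+ # A minimal sketch with TAESD-like hyper-parameters (illustrative values, an
+ # assumption rather than this repo's config): the tiny decoder soft-clamps
+ # latents with tanh(x / 3) * 3 and maps its [0, 1] output back to [-1, 1].
+ #   taesd_dec = DecoderTiny(in_channels=4, out_channels=3, num_blocks=(3, 3, 3, 1),
+ #                           block_out_channels=(64, 64, 64, 64),
+ #                           upsampling_scaling_factor=2, act_fn="relu")
+ #   image = taesd_dec(torch.randn(1, 4, 32, 32))  # -> (1, 3, 256, 256)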
module/ip_adapter/attention_processor.py ADDED
@@ -0,0 +1,1467 @@
1
+ # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ class AdaLayerNorm(nn.Module):
7
+ def __init__(self, embedding_dim: int, time_embedding_dim: int = None):
8
+ super().__init__()
9
+
10
+ if time_embedding_dim is None:
11
+ time_embedding_dim = embedding_dim
12
+
13
+ self.silu = nn.SiLU()
14
+ self.linear = nn.Linear(time_embedding_dim, 2 * embedding_dim, bias=True)
15
+ nn.init.zeros_(self.linear.weight)
16
+ nn.init.zeros_(self.linear.bias)
17
+
18
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
19
+
20
+ def forward(
21
+ self, x: torch.Tensor, timestep_embedding: torch.Tensor
22
+ ):
23
+ emb = self.linear(self.silu(timestep_embedding))
24
+ shift, scale = emb.view(len(x), 1, -1).chunk(2, dim=-1)
25
+ x = self.norm(x) * (1 + scale) + shift
26
+ return x
27
+
28
+
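+ # Note (reading of the code above): the modulation projection is zero-initialized,
+ # so AdaLayerNorm starts out as a plain LayerNorm and only learns a timestep-
+ # dependent shift/scale during training. A minimal sketch with illustrative sizes:
+ #   ln = AdaLayerNorm(embedding_dim=640, time_embedding_dim=1280)
+ #   y = ln(torch.randn(2, 77, 640), torch.randn(2, 1280))  # same shape as the input
+ 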
29
+ class AttnProcessor(nn.Module):
30
+ r"""
31
+ Default processor for performing attention-related computations.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ hidden_size=None,
37
+ cross_attention_dim=None,
38
+ ):
39
+ super().__init__()
40
+
41
+ def __call__(
42
+ self,
43
+ attn,
44
+ hidden_states,
45
+ encoder_hidden_states=None,
46
+ attention_mask=None,
47
+ temb=None,
48
+ ):
49
+ residual = hidden_states
50
+
51
+ if attn.spatial_norm is not None:
52
+ hidden_states = attn.spatial_norm(hidden_states, temb)
53
+
54
+ input_ndim = hidden_states.ndim
55
+
56
+ if input_ndim == 4:
57
+ batch_size, channel, height, width = hidden_states.shape
58
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
59
+
60
+ batch_size, sequence_length, _ = (
61
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
62
+ )
63
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
64
+
65
+ if attn.group_norm is not None:
66
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
67
+
68
+ query = attn.to_q(hidden_states)
69
+
70
+ if encoder_hidden_states is None:
71
+ encoder_hidden_states = hidden_states
72
+ elif attn.norm_cross:
73
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
74
+
75
+ key = attn.to_k(encoder_hidden_states)
76
+ value = attn.to_v(encoder_hidden_states)
77
+
78
+ query = attn.head_to_batch_dim(query)
79
+ key = attn.head_to_batch_dim(key)
80
+ value = attn.head_to_batch_dim(value)
81
+
82
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
83
+ hidden_states = torch.bmm(attention_probs, value)
84
+ hidden_states = attn.batch_to_head_dim(hidden_states)
85
+
86
+ # linear proj
87
+ hidden_states = attn.to_out[0](hidden_states)
88
+ # dropout
89
+ hidden_states = attn.to_out[1](hidden_states)
90
+
91
+ if input_ndim == 4:
92
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
93
+
94
+ if attn.residual_connection:
95
+ hidden_states = hidden_states + residual
96
+
97
+ hidden_states = hidden_states / attn.rescale_output_factor
98
+
99
+ return hidden_states
100
+
101
+
102
+ class IPAttnProcessor(nn.Module):
103
+ r"""
104
+ Attention processor for IP-Adapter.
105
+ Args:
106
+ hidden_size (`int`):
107
+ The hidden size of the attention layer.
108
+ cross_attention_dim (`int`):
109
+ The number of channels in the `encoder_hidden_states`.
110
+ scale (`float`, defaults to 1.0):
111
+ the weight scale of image prompt.
112
+ num_tokens (`int`, defaults to 4; should be 16 when using ip_adapter_plus):
113
+ The context length of the image features.
114
+ """
115
+
116
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
117
+ super().__init__()
118
+
119
+ self.hidden_size = hidden_size
120
+ self.cross_attention_dim = cross_attention_dim
121
+ self.scale = scale
122
+ self.num_tokens = num_tokens
123
+
124
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
125
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
126
+
127
+ def __call__(
128
+ self,
129
+ attn,
130
+ hidden_states,
131
+ encoder_hidden_states=None,
132
+ attention_mask=None,
133
+ temb=None,
134
+ ):
135
+ residual = hidden_states
136
+
137
+ if attn.spatial_norm is not None:
138
+ hidden_states = attn.spatial_norm(hidden_states, temb)
139
+
140
+ input_ndim = hidden_states.ndim
141
+
142
+ if input_ndim == 4:
143
+ batch_size, channel, height, width = hidden_states.shape
144
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
145
+
146
+ batch_size, sequence_length, _ = (
147
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
148
+ )
149
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
150
+
151
+ if attn.group_norm is not None:
152
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
153
+
154
+ query = attn.to_q(hidden_states)
155
+
156
+ if encoder_hidden_states is None:
157
+ encoder_hidden_states = hidden_states
158
+ else:
159
+ # get encoder_hidden_states, ip_hidden_states
160
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
161
+ encoder_hidden_states, ip_hidden_states = (
162
+ encoder_hidden_states[:, :end_pos, :],
163
+ encoder_hidden_states[:, end_pos:, :],
164
+ )
165
+ if attn.norm_cross:
166
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
167
+
168
+ key = attn.to_k(encoder_hidden_states)
169
+ value = attn.to_v(encoder_hidden_states)
170
+
171
+ query = attn.head_to_batch_dim(query)
172
+ key = attn.head_to_batch_dim(key)
173
+ value = attn.head_to_batch_dim(value)
174
+
175
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
176
+ hidden_states = torch.bmm(attention_probs, value)
177
+ hidden_states = attn.batch_to_head_dim(hidden_states)
178
+
179
+ # for ip-adapter
180
+ ip_key = self.to_k_ip(ip_hidden_states)
181
+ ip_value = self.to_v_ip(ip_hidden_states)
182
+
183
+ ip_key = attn.head_to_batch_dim(ip_key)
184
+ ip_value = attn.head_to_batch_dim(ip_value)
185
+
186
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
187
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
188
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
189
+
190
+ hidden_states = hidden_states + self.scale * ip_hidden_states
191
+
192
+ # linear proj
193
+ hidden_states = attn.to_out[0](hidden_states)
194
+ # dropout
195
+ hidden_states = attn.to_out[1](hidden_states)
196
+
197
+ if input_ndim == 4:
198
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
199
+
200
+ if attn.residual_connection:
201
+ hidden_states = hidden_states + residual
202
+
203
+ hidden_states = hidden_states / attn.rescale_output_factor
204
+
205
+ return hidden_states
206
+
207
+
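+ # Note (reading of the code above): the last `num_tokens` entries of
+ # `encoder_hidden_states` are treated as image-prompt tokens with their own
+ # key/value projections, and their attention output is blended back in as
+ #   hidden_states = hidden_states + scale * ip_hidden_states
+ # so `scale=0.0` recovers plain text-conditioned cross-attention.
+ 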
208
+ class TA_IPAttnProcessor(nn.Module):
209
+ r"""
210
+ Attention processor for IP-Adapter.
211
+ Args:
212
+ hidden_size (`int`):
213
+ The hidden size of the attention layer.
214
+ cross_attention_dim (`int`):
215
+ The number of channels in the `encoder_hidden_states`.
216
+ scale (`float`, defaults to 1.0):
217
+ the weight scale of image prompt.
218
+ num_tokens (`int`, defaults to 4; should be 16 when using ip_adapter_plus):
219
+ The context length of the image features.
220
+ """
221
+
222
+ def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4):
223
+ super().__init__()
224
+
225
+ self.hidden_size = hidden_size
226
+ self.cross_attention_dim = cross_attention_dim
227
+ self.scale = scale
228
+ self.num_tokens = num_tokens
229
+
230
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
231
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
232
+
233
+ self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
234
+ self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
235
+
236
+ def __call__(
237
+ self,
238
+ attn,
239
+ hidden_states,
240
+ encoder_hidden_states=None,
241
+ attention_mask=None,
242
+ temb=None,
243
+ ):
244
+ assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
245
+
246
+ residual = hidden_states
247
+
248
+ if attn.spatial_norm is not None:
249
+ hidden_states = attn.spatial_norm(hidden_states, temb)
250
+
251
+ input_ndim = hidden_states.ndim
252
+
253
+ if input_ndim == 4:
254
+ batch_size, channel, height, width = hidden_states.shape
255
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
256
+
257
+ batch_size, sequence_length, _ = (
258
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
259
+ )
260
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
261
+
262
+ if attn.group_norm is not None:
263
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
264
+
265
+ query = attn.to_q(hidden_states)
266
+
267
+ if encoder_hidden_states is None:
268
+ encoder_hidden_states = hidden_states
269
+ else:
270
+ # get encoder_hidden_states, ip_hidden_states
271
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
272
+ encoder_hidden_states, ip_hidden_states = (
273
+ encoder_hidden_states[:, :end_pos, :],
274
+ encoder_hidden_states[:, end_pos:, :],
275
+ )
276
+ if attn.norm_cross:
277
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
278
+
279
+ key = attn.to_k(encoder_hidden_states)
280
+ value = attn.to_v(encoder_hidden_states)
281
+
282
+ query = attn.head_to_batch_dim(query)
283
+ key = attn.head_to_batch_dim(key)
284
+ value = attn.head_to_batch_dim(value)
285
+
286
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
287
+ hidden_states = torch.bmm(attention_probs, value)
288
+ hidden_states = attn.batch_to_head_dim(hidden_states)
289
+
290
+ # for ip-adapter
291
+ ip_key = self.to_k_ip(ip_hidden_states)
292
+ ip_value = self.to_v_ip(ip_hidden_states)
293
+
294
+ # time-dependent adaLN
295
+ ip_key = self.ln_k_ip(ip_key, temb)
296
+ ip_value = self.ln_v_ip(ip_value, temb)
297
+
298
+ ip_key = attn.head_to_batch_dim(ip_key)
299
+ ip_value = attn.head_to_batch_dim(ip_value)
300
+
301
+ ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
302
+ ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
303
+ ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
304
+
305
+ hidden_states = hidden_states + self.scale * ip_hidden_states
306
+
307
+ # linear proj
308
+ hidden_states = attn.to_out[0](hidden_states)
309
+ # dropout
310
+ hidden_states = attn.to_out[1](hidden_states)
311
+
312
+ if input_ndim == 4:
313
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
314
+
315
+ if attn.residual_connection:
316
+ hidden_states = hidden_states + residual
317
+
318
+ hidden_states = hidden_states / attn.rescale_output_factor
319
+
320
+ return hidden_states
321
+
322
+
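+ # Note (reading of the code above): TA_IPAttnProcessor differs from
+ # IPAttnProcessor only in that the image-prompt keys/values first pass through
+ # timestep-conditioned AdaLayerNorms (`ln_k_ip`/`ln_v_ip`), so the image
+ # guidance can vary over the denoising trajectory; `temb` is therefore required.
+ 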
323
+ class AttnProcessor2_0(torch.nn.Module):
324
+ r"""
325
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
326
+ """
327
+
328
+ def __init__(
329
+ self,
330
+ hidden_size=None,
331
+ cross_attention_dim=None,
332
+ ):
333
+ super().__init__()
334
+ if not hasattr(F, "scaled_dot_product_attention"):
335
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
336
+
337
+ def __call__(
338
+ self,
339
+ attn,
340
+ hidden_states,
341
+ encoder_hidden_states=None,
342
+ attention_mask=None,
343
+ external_kv=None,
344
+ temb=None,
345
+ ):
346
+ residual = hidden_states
347
+
348
+ if attn.spatial_norm is not None:
349
+ hidden_states = attn.spatial_norm(hidden_states, temb)
350
+
351
+ input_ndim = hidden_states.ndim
352
+
353
+ if input_ndim == 4:
354
+ batch_size, channel, height, width = hidden_states.shape
355
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
356
+
357
+ batch_size, sequence_length, _ = (
358
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
359
+ )
360
+
361
+ if attention_mask is not None:
362
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
363
+ # scaled_dot_product_attention expects attention_mask shape to be
364
+ # (batch, heads, source_length, target_length)
365
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
366
+
367
+ if attn.group_norm is not None:
368
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
369
+
370
+ query = attn.to_q(hidden_states)
371
+
372
+ if encoder_hidden_states is None:
373
+ encoder_hidden_states = hidden_states
374
+ elif attn.norm_cross:
375
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
376
+
377
+ key = attn.to_k(encoder_hidden_states)
378
+ value = attn.to_v(encoder_hidden_states)
379
+
380
+ if external_kv:
381
+ key = torch.cat([key, external_kv.k], axis=1)
382
+ value = torch.cat([value, external_kv.v], axis=1)
383
+
384
+ inner_dim = key.shape[-1]
385
+ head_dim = inner_dim // attn.heads
386
+
387
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
388
+
389
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
390
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
391
+
392
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
393
+ # TODO: add support for attn.scale when we move to Torch 2.1
394
+ hidden_states = F.scaled_dot_product_attention(
395
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
396
+ )
397
+
398
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
399
+ hidden_states = hidden_states.to(query.dtype)
400
+
401
+ # linear proj
402
+ hidden_states = attn.to_out[0](hidden_states)
403
+ # dropout
404
+ hidden_states = attn.to_out[1](hidden_states)
405
+
406
+ if input_ndim == 4:
407
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
408
+
409
+ if attn.residual_connection:
410
+ hidden_states = hidden_states + residual
411
+
412
+ hidden_states = hidden_states / attn.rescale_output_factor
413
+
414
+ return hidden_states
415
+
416
+
417
+ class split_AttnProcessor2_0(torch.nn.Module):
418
+ r"""
419
+ Scaled dot-product attention processor (requires PyTorch 2.0) that splits a spatially concatenated input into its two halves, attends jointly over both, and re-assembles the padded concatenation.
420
+ """
421
+
422
+ def __init__(
423
+ self,
424
+ hidden_size=None,
425
+ cross_attention_dim=None,
426
+ time_embedding_dim=None,
427
+ ):
428
+ super().__init__()
429
+ if not hasattr(F, "scaled_dot_product_attention"):
430
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
431
+
432
+ def __call__(
433
+ self,
434
+ attn,
435
+ hidden_states,
436
+ encoder_hidden_states=None,
437
+ attention_mask=None,
438
+ external_kv=None,
439
+ temb=None,
440
+ cat_dim=-2,
441
+ original_shape=None,
442
+ ):
443
+ residual = hidden_states
444
+
445
+ if attn.spatial_norm is not None:
446
+ hidden_states = attn.spatial_norm(hidden_states, temb)
447
+
448
+ input_ndim = hidden_states.ndim
449
+
450
+ if input_ndim == 4:
451
+ # 2d to sequence.
452
+ height, width = hidden_states.shape[-2:]
453
+ if cat_dim==-2 or cat_dim==2:
454
+ hidden_states_0 = hidden_states[:, :, :height//2, :]
455
+ hidden_states_1 = hidden_states[:, :, -(height//2):, :]
456
+ elif cat_dim==-1 or cat_dim==3:
457
+ hidden_states_0 = hidden_states[:, :, :, :width//2]
458
+ hidden_states_1 = hidden_states[:, :, :, -(width//2):]
459
+ batch_size, channel, height, width = hidden_states_0.shape
460
+ hidden_states_0 = hidden_states_0.view(batch_size, channel, height * width).transpose(1, 2)
461
+ hidden_states_1 = hidden_states_1.view(batch_size, channel, height * width).transpose(1, 2)
462
+ else:
463
+ # directly split the sequence according to the concat dim.
464
+ single_dim = original_shape[2] if cat_dim==-2 or cat_dim==2 else original_shape[1]
465
+ hidden_states_0 = hidden_states[:, :single_dim*single_dim,:]
466
+ hidden_states_1 = hidden_states[:, single_dim*(single_dim+1):,:]
467
+
468
+ hidden_states = torch.cat([hidden_states_0, hidden_states_1], dim=1)
469
+ batch_size, sequence_length, _ = (
470
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
471
+ )
472
+
473
+ if attention_mask is not None:
474
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
475
+ # scaled_dot_product_attention expects attention_mask shape to be
476
+ # (batch, heads, source_length, target_length)
477
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
478
+
479
+ if attn.group_norm is not None:
480
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
481
+
482
+ query = attn.to_q(hidden_states)
483
+ key = attn.to_k(hidden_states)
484
+ value = attn.to_v(hidden_states)
485
+
486
+ if external_kv:
487
+ key = torch.cat([key, external_kv.k], dim=1)
488
+ value = torch.cat([value, external_kv.v], dim=1)
489
+
490
+ inner_dim = key.shape[-1]
491
+ head_dim = inner_dim // attn.heads
492
+
493
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
494
+
495
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
496
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
497
+
498
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
499
+ # TODO: add support for attn.scale when we move to Torch 2.1
500
+ hidden_states = F.scaled_dot_product_attention(
501
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
502
+ )
503
+
504
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
505
+ hidden_states = hidden_states.to(query.dtype)
506
+
507
+ # linear proj
508
+ hidden_states = attn.to_out[0](hidden_states)
509
+ # dropout
510
+ hidden_states = attn.to_out[1](hidden_states)
511
+
512
+ # spatially split.
513
+ hidden_states_0, hidden_states_1 = hidden_states.chunk(2, dim=1)
514
+
515
+ if input_ndim == 4:
516
+ hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width)
517
+ hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width)
518
+
519
+ if cat_dim==-2 or cat_dim==2:
520
+ hidden_states_pad = torch.zeros(batch_size, channel, 1, width)
521
+ elif cat_dim==-1 or cat_dim==3:
522
+ hidden_states_pad = torch.zeros(batch_size, channel, height, 1)
523
+ hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
524
+ hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim)
525
+ assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
526
+ else:
527
+ batch_size, sequence_length, inner_dim = hidden_states.shape
528
+ hidden_states_pad = torch.zeros(batch_size, single_dim, inner_dim)
529
+ hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
530
+ hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1)
531
+ assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
532
+
533
+ if attn.residual_connection:
534
+ hidden_states = hidden_states + residual
535
+
536
+ hidden_states = hidden_states / attn.rescale_output_factor
537
+
538
+ return hidden_states
539
+
540
+
541
+ class sep_split_AttnProcessor2_0(torch.nn.Module):
542
+ r"""
543
+ Scaled dot-product attention processor (requires PyTorch 2.0) that attends over the two halves of a spatially concatenated input separately, adding cross-attention between them with timestep-modulated (AdaLayerNorm) reference keys and values.
544
+ """
545
+
546
+ def __init__(
547
+ self,
548
+ hidden_size=None,
549
+ cross_attention_dim=None,
550
+ time_embedding_dim=None,
551
+ ):
552
+ super().__init__()
553
+ if not hasattr(F, "scaled_dot_product_attention"):
554
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
555
+ self.ln_k_ref = AdaLayerNorm(hidden_size, time_embedding_dim)
556
+ self.ln_v_ref = AdaLayerNorm(hidden_size, time_embedding_dim)
557
+ # self.hidden_size = hidden_size
558
+ # self.cross_attention_dim = cross_attention_dim
559
+ # self.scale = scale
560
+ # self.num_tokens = num_tokens
561
+
562
+ # self.to_q_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
563
+ # self.to_k_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
564
+ # self.to_v_ref = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
565
+
566
+ def __call__(
567
+ self,
568
+ attn,
569
+ hidden_states,
570
+ encoder_hidden_states=None,
571
+ attention_mask=None,
572
+ external_kv=None,
573
+ temb=None,
574
+ cat_dim=-2,
575
+ original_shape=None,
576
+ ref_scale=1.0,
577
+ ):
578
+ residual = hidden_states
579
+
580
+ if attn.spatial_norm is not None:
581
+ hidden_states = attn.spatial_norm(hidden_states, temb)
582
+
583
+ input_ndim = hidden_states.ndim
584
+
585
+ if input_ndim == 4:
586
+ # 2d to sequence.
587
+ height, width = hidden_states.shape[-2:]
588
+ if cat_dim==-2 or cat_dim==2:
589
+ hidden_states_0 = hidden_states[:, :, :height//2, :]
590
+ hidden_states_1 = hidden_states[:, :, -(height//2):, :]
591
+ elif cat_dim==-1 or cat_dim==3:
592
+ hidden_states_0 = hidden_states[:, :, :, :width//2]
593
+ hidden_states_1 = hidden_states[:, :, :, -(width//2):]
594
+ batch_size, channel, height, width = hidden_states_0.shape
595
+ hidden_states_0 = hidden_states_0.view(batch_size, channel, height * width).transpose(1, 2)
596
+ hidden_states_1 = hidden_states_1.view(batch_size, channel, height * width).transpose(1, 2)
597
+ else:
598
+ # directly split the sequence according to the concat dim.
599
+ single_dim = original_shape[2] if cat_dim==-2 or cat_dim==2 else original_shape[1]
600
+ hidden_states_0 = hidden_states[:, :single_dim*single_dim,:]
601
+ hidden_states_1 = hidden_states[:, single_dim*(single_dim+1):,:]
602
+
603
+ batch_size, sequence_length, _ = (
604
+ hidden_states_0.shape if encoder_hidden_states is None else encoder_hidden_states.shape
605
+ )
606
+
607
+ if attention_mask is not None:
608
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
609
+ # scaled_dot_product_attention expects attention_mask shape to be
610
+ # (batch, heads, source_length, target_length)
611
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
612
+
613
+ if attn.group_norm is not None:
614
+ hidden_states_0 = attn.group_norm(hidden_states_0.transpose(1, 2)).transpose(1, 2)
615
+ hidden_states_1 = attn.group_norm(hidden_states_1.transpose(1, 2)).transpose(1, 2)
616
+
617
+ query_0 = attn.to_q(hidden_states_0)
618
+ query_1 = attn.to_q(hidden_states_1)
619
+ key_0 = attn.to_k(hidden_states_0)
620
+ key_1 = attn.to_k(hidden_states_1)
621
+ value_0 = attn.to_v(hidden_states_0)
622
+ value_1 = attn.to_v(hidden_states_1)
623
+
624
+ # time-dependent adaLN
625
+ key_1 = self.ln_k_ref(key_1, temb)
626
+ value_1 = self.ln_v_ref(value_1, temb)
627
+
628
+ if external_kv:
629
+ key_1 = torch.cat([key_1, external_kv.k], dim=1)
630
+ value_1 = torch.cat([value_1, external_kv.v], dim=1)
631
+
632
+ inner_dim = key_0.shape[-1]
633
+ head_dim = inner_dim // attn.heads
634
+
635
+ query_0 = query_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
636
+ query_1 = query_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
637
+ key_0 = key_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
638
+ key_1 = key_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
639
+ value_0 = value_0.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
640
+ value_1 = value_1.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
641
+
642
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
643
+ # TODO: add support for attn.scale when we move to Torch 2.1
644
+ hidden_states_0 = F.scaled_dot_product_attention(
645
+ query_0, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
646
+ )
647
+ hidden_states_1 = F.scaled_dot_product_attention(
648
+ query_1, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
649
+ )
650
+
651
+ # cross-attn
652
+ _hidden_states_0 = F.scaled_dot_product_attention(
653
+ query_0, key_1, value_1, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
654
+ )
655
+ hidden_states_0 = hidden_states_0 + ref_scale * _hidden_states_0 * 10
656
+
657
+ # TODO: drop this cross-attn
658
+ _hidden_states_1 = F.scaled_dot_product_attention(
659
+ query_1, key_0, value_0, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
660
+ )
661
+ hidden_states_1 = hidden_states_1 + ref_scale * _hidden_states_1
662
+
663
+ hidden_states_0 = hidden_states_0.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
664
+ hidden_states_1 = hidden_states_1.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
665
+ hidden_states_0 = hidden_states_0.to(query_0.dtype)
666
+ hidden_states_1 = hidden_states_1.to(query_1.dtype)
667
+
668
+
669
+ # linear proj
670
+ hidden_states_0 = attn.to_out[0](hidden_states_0)
671
+ hidden_states_1 = attn.to_out[0](hidden_states_1)
672
+ # dropout
673
+ hidden_states_0 = attn.to_out[1](hidden_states_0)
674
+ hidden_states_1 = attn.to_out[1](hidden_states_1)
675
+
676
+
677
+ if input_ndim == 4:
678
+ hidden_states_0 = hidden_states_0.transpose(-1, -2).reshape(batch_size, channel, height, width)
679
+ hidden_states_1 = hidden_states_1.transpose(-1, -2).reshape(batch_size, channel, height, width)
680
+
681
+ if cat_dim==-2 or cat_dim==2:
682
+ hidden_states_pad = torch.zeros(batch_size, channel, 1, width)
683
+ elif cat_dim==-1 or cat_dim==3:
684
+ hidden_states_pad = torch.zeros(batch_size, channel, height, 1)
685
+ hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
686
+ hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=cat_dim)
687
+ assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
688
+ else:
689
+ batch_size, sequence_length, inner_dim = hidden_states.shape
690
+ hidden_states_pad = torch.zeros(batch_size, single_dim, inner_dim)
691
+ hidden_states_pad = hidden_states_pad.to(hidden_states_0.device, dtype=hidden_states_0.dtype)
692
+ hidden_states = torch.cat([hidden_states_0, hidden_states_pad, hidden_states_1], dim=1)
693
+ assert hidden_states.shape == residual.shape, f"{hidden_states.shape} != {residual.shape}"
694
+
695
+ if attn.residual_connection:
696
+ hidden_states = hidden_states + residual
697
+
698
+ hidden_states = hidden_states / attn.rescale_output_factor
699
+
700
+ return hidden_states
701
+
702
+
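For orientation, a minimal sketch of the tensor layout the split processor above expects; the sizes are illustrative assumptions, not values taken from this repository:

import torch

# Two feature maps concatenated along the height axis (cat_dim=-2) with a one-row zero
# separator; the processor splits them back into two streams, attends within and across
# the streams, and re-inserts the same zero pad before the residual connection.
x0 = torch.randn(1, 320, 64, 64)                    # main stream
x1 = torch.randn(1, 320, 64, 64)                    # reference stream
pad = torch.zeros(1, 320, 1, 64)                    # separator row
hidden_states = torch.cat([x0, pad, x1], dim=-2)    # (1, 320, 129, 64)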
703
+ class AdditiveKV_AttnProcessor2_0(torch.nn.Module):
704
+ r"""
705
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
706
+ """
707
+
708
+ def __init__(
709
+ self,
710
+ hidden_size: int = None,
711
+ cross_attention_dim: int = None,
712
+ time_embedding_dim: int = None,
713
+ additive_scale: float = 1.0,
714
+ ):
715
+ super().__init__()
716
+ if not hasattr(F, "scaled_dot_product_attention"):
717
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
718
+ self.additive_scale = additive_scale
719
+
720
+ def __call__(
721
+ self,
722
+ attn,
723
+ hidden_states,
724
+ encoder_hidden_states=None,
725
+ external_kv=None,
726
+ attention_mask=None,
727
+ temb=None,
728
+ ):
729
+ assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
730
+
731
+ residual = hidden_states
732
+
733
+ if attn.spatial_norm is not None:
734
+ hidden_states = attn.spatial_norm(hidden_states, temb)
735
+
736
+ input_ndim = hidden_states.ndim
737
+
738
+ if input_ndim == 4:
739
+ batch_size, channel, height, width = hidden_states.shape
740
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
741
+
742
+ batch_size, sequence_length, _ = (
743
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
744
+ )
745
+
746
+ if attention_mask is not None:
747
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
748
+ # scaled_dot_product_attention expects attention_mask shape to be
749
+ # (batch, heads, source_length, target_length)
750
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
751
+
752
+ if attn.group_norm is not None:
753
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
754
+
755
+ query = attn.to_q(hidden_states)
756
+
757
+ if encoder_hidden_states is None:
758
+ encoder_hidden_states = hidden_states
759
+ elif attn.norm_cross:
760
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
761
+
762
+ key = attn.to_k(encoder_hidden_states)
763
+ value = attn.to_v(encoder_hidden_states)
764
+
765
+ inner_dim = key.shape[-1]
766
+ head_dim = inner_dim // attn.heads
767
+
768
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
769
+
770
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
771
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
772
+
773
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
774
+ # TODO: add support for attn.scale when we move to Torch 2.1
775
+ hidden_states = F.scaled_dot_product_attention(
776
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
777
+ )
778
+
779
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
780
+
781
+ if external_kv:
782
+ key = external_kv.k
783
+ value = external_kv.v
784
+
785
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
786
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
787
+
788
+ external_attn_output = F.scaled_dot_product_attention(
789
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
790
+ )
791
+
792
+ external_attn_output = external_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
793
+ hidden_states = hidden_states + self.additive_scale * external_attn_output
794
+
795
+ hidden_states = hidden_states.to(query.dtype)
796
+
797
+ # linear proj
798
+ hidden_states = attn.to_out[0](hidden_states)
799
+ # dropout
800
+ hidden_states = attn.to_out[1](hidden_states)
801
+
802
+ if input_ndim == 4:
803
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
804
+
805
+ if attn.residual_connection:
806
+ hidden_states = hidden_states + residual
807
+
808
+ hidden_states = hidden_states / attn.rescale_output_factor
809
+
810
+ return hidden_states
811
+
812
+
813
+ class TA_AdditiveKV_AttnProcessor2_0(torch.nn.Module):
814
+ r"""
815
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
816
+ """
817
+
818
+ def __init__(
819
+ self,
820
+ hidden_size: int = None,
821
+ cross_attention_dim: int = None,
822
+ time_embedding_dim: int = None,
823
+ additive_scale: float = 1.0,
824
+ ):
825
+ super().__init__()
826
+ if not hasattr(F, "scaled_dot_product_attention"):
827
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
828
+ self.ln_k = AdaLayerNorm(hidden_size, time_embedding_dim)
829
+ self.ln_v = AdaLayerNorm(hidden_size, time_embedding_dim)
830
+ self.additive_scale = additive_scale
831
+
832
+ def __call__(
833
+ self,
834
+ attn,
835
+ hidden_states,
836
+ encoder_hidden_states=None,
837
+ external_kv=None,
838
+ attention_mask=None,
839
+ temb=None,
840
+ ):
841
+ assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
842
+
843
+ residual = hidden_states
844
+
845
+ if attn.spatial_norm is not None:
846
+ hidden_states = attn.spatial_norm(hidden_states, temb)
847
+
848
+ input_ndim = hidden_states.ndim
849
+
850
+ if input_ndim == 4:
851
+ batch_size, channel, height, width = hidden_states.shape
852
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
853
+
854
+ batch_size, sequence_length, _ = (
855
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
856
+ )
857
+
858
+ if attention_mask is not None:
859
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
860
+ # scaled_dot_product_attention expects attention_mask shape to be
861
+ # (batch, heads, source_length, target_length)
862
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
863
+
864
+ if attn.group_norm is not None:
865
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
866
+
867
+ query = attn.to_q(hidden_states)
868
+
869
+ if encoder_hidden_states is None:
870
+ encoder_hidden_states = hidden_states
871
+ elif attn.norm_cross:
872
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
873
+
874
+ key = attn.to_k(encoder_hidden_states)
875
+ value = attn.to_v(encoder_hidden_states)
876
+
877
+ inner_dim = key.shape[-1]
878
+ head_dim = inner_dim // attn.heads
879
+
880
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
881
+
882
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
883
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
884
+
885
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
886
+ # TODO: add support for attn.scale when we move to Torch 2.1
887
+ hidden_states = F.scaled_dot_product_attention(
888
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
889
+ )
890
+
891
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
892
+
893
+ if external_kv:
894
+ key = external_kv.k
895
+ value = external_kv.v
896
+
897
+ # time-dependent adaLN
898
+ key = self.ln_k(key, temb)
899
+ value = self.ln_v(value, temb)
900
+
901
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
902
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
903
+
904
+ external_attn_output = F.scaled_dot_product_attention(
905
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
906
+ )
907
+
908
+ external_attn_output = external_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
909
+ hidden_states = hidden_states + self.additive_scale * external_attn_output
910
+
911
+ hidden_states = hidden_states.to(query.dtype)
912
+
913
+ # linear proj
914
+ hidden_states = attn.to_out[0](hidden_states)
915
+ # dropout
916
+ hidden_states = attn.to_out[1](hidden_states)
917
+
918
+ if input_ndim == 4:
919
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
920
+
921
+ if attn.residual_connection:
922
+ hidden_states = hidden_states + residual
923
+
924
+ hidden_states = hidden_states / attn.rescale_output_factor
925
+
926
+ return hidden_states
927
+
928
+
929
+ class IPAttnProcessor2_0(torch.nn.Module):
930
+ r"""
931
+ Attention processor for IP-Adapter for PyTorch 2.0.
932
+ Args:
933
+ hidden_size (`int`):
934
+ The hidden size of the attention layer.
935
+ cross_attention_dim (`int`):
936
+ The number of channels in the `encoder_hidden_states`.
937
+ scale (`float`, defaults to 1.0):
938
+ the weight scale of image prompt.
939
+ num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
940
+ The context length of the image features.
941
+ """
942
+
943
+ def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
944
+ super().__init__()
945
+
946
+ if not hasattr(F, "scaled_dot_product_attention"):
947
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
948
+
949
+ self.hidden_size = hidden_size
950
+ self.cross_attention_dim = cross_attention_dim
951
+ self.scale = scale
952
+ self.num_tokens = num_tokens
953
+
954
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
955
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
956
+
957
+ def __call__(
958
+ self,
959
+ attn,
960
+ hidden_states,
961
+ encoder_hidden_states=None,
962
+ attention_mask=None,
963
+ temb=None,
964
+ ):
965
+ residual = hidden_states
966
+
967
+ if attn.spatial_norm is not None:
968
+ hidden_states = attn.spatial_norm(hidden_states, temb)
969
+
970
+ input_ndim = hidden_states.ndim
971
+
972
+ if input_ndim == 4:
973
+ batch_size, channel, height, width = hidden_states.shape
974
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
975
+
976
+ if isinstance(encoder_hidden_states, tuple):
977
+ # FIXME: now hard coded to single image prompt.
978
+ batch_size, _, hid_dim = encoder_hidden_states[0].shape
979
+ ip_tokens = encoder_hidden_states[1][0]
980
+ encoder_hidden_states = torch.cat([encoder_hidden_states[0], ip_tokens], dim=1)
981
+
982
+ batch_size, sequence_length, _ = (
983
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
984
+ )
985
+
986
+ if attention_mask is not None:
987
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
988
+ # scaled_dot_product_attention expects attention_mask shape to be
989
+ # (batch, heads, source_length, target_length)
990
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
991
+
992
+ if attn.group_norm is not None:
993
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
994
+
995
+ query = attn.to_q(hidden_states)
996
+
997
+ if encoder_hidden_states is None:
998
+ encoder_hidden_states = hidden_states
999
+ else:
1000
+ # get encoder_hidden_states, ip_hidden_states
1001
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
1002
+ encoder_hidden_states, ip_hidden_states = (
1003
+ encoder_hidden_states[:, :end_pos, :],
1004
+ encoder_hidden_states[:, end_pos:, :],
1005
+ )
1006
+ if attn.norm_cross:
1007
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1008
+
1009
+ key = attn.to_k(encoder_hidden_states)
1010
+ value = attn.to_v(encoder_hidden_states)
1011
+
1012
+ inner_dim = key.shape[-1]
1013
+ head_dim = inner_dim // attn.heads
1014
+
1015
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1016
+
1017
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1018
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1019
+
1020
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1021
+ # TODO: add support for attn.scale when we move to Torch 2.1
1022
+ hidden_states = F.scaled_dot_product_attention(
1023
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1024
+ )
1025
+
1026
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1027
+ hidden_states = hidden_states.to(query.dtype)
1028
+
1029
+ # for ip-adapter
1030
+ ip_key = self.to_k_ip(ip_hidden_states)
1031
+ ip_value = self.to_v_ip(ip_hidden_states)
1032
+
1033
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1034
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1035
+
1036
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1037
+ # TODO: add support for attn.scale when we move to Torch 2.1
1038
+ ip_hidden_states = F.scaled_dot_product_attention(
1039
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
1040
+ )
1041
+
1042
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1043
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
1044
+
1045
+ hidden_states = hidden_states + self.scale * ip_hidden_states
1046
+
1047
+ # linear proj
1048
+ hidden_states = attn.to_out[0](hidden_states)
1049
+ # dropout
1050
+ hidden_states = attn.to_out[1](hidden_states)
1051
+
1052
+ if input_ndim == 4:
1053
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1054
+
1055
+ if attn.residual_connection:
1056
+ hidden_states = hidden_states + residual
1057
+
1058
+ hidden_states = hidden_states / attn.rescale_output_factor
1059
+
1060
+ return hidden_states
1061
+
1062
+
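For reference, a small sketch (with assumed sizes) of the sequence layout `IPAttnProcessor2_0` expects when text and image-prompt tokens arrive as a single tensor:

import torch

num_tokens = 4
text_tokens = torch.randn(1, 77, 2048)                               # SDXL text-encoder sequence (assumed length 77)
ip_tokens = torch.randn(1, num_tokens, 2048)                         # image-prompt tokens from the projection model
encoder_hidden_states = torch.cat([text_tokens, ip_tokens], dim=1)   # (1, 81, 2048)

# Inside __call__ the two parts are recovered with:
end_pos = encoder_hidden_states.shape[1] - num_tokens
text_part = encoder_hidden_states[:, :end_pos, :]                    # (1, 77, 2048) -> regular cross-attention
ip_part = encoder_hidden_states[:, end_pos:, :]                      # (1, 4, 2048)  -> to_k_ip / to_v_ip branch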
1063
+ class TA_IPAttnProcessor2_0(torch.nn.Module):
1064
+ r"""
1065
+ Attention processor for IP-Adapter for PyTorch 2.0.
1066
+ Args:
1067
+ hidden_size (`int`):
1068
+ The hidden size of the attention layer.
1069
+ cross_attention_dim (`int`):
1070
+ The number of channels in the `encoder_hidden_states`.
1071
+ scale (`float`, defaults to 1.0):
1072
+ the weight scale of image prompt.
1073
+ num_tokens (`int`, defaults to 4; when using ip_adapter_plus it should be 16):
1074
+ The context length of the image features.
1075
+ """
1076
+
1077
+ def __init__(self, hidden_size, cross_attention_dim=None, time_embedding_dim: int = None, scale=1.0, num_tokens=4):
1078
+ super().__init__()
1079
+
1080
+ if not hasattr(F, "scaled_dot_product_attention"):
1081
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1082
+
1083
+ self.hidden_size = hidden_size
1084
+ self.cross_attention_dim = cross_attention_dim
1085
+ self.scale = scale
1086
+ self.num_tokens = num_tokens
1087
+
1088
+ self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1089
+ self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1090
+ self.ln_k_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
1091
+ self.ln_v_ip = AdaLayerNorm(hidden_size, time_embedding_dim)
1092
+
1093
+ def __call__(
1094
+ self,
1095
+ attn,
1096
+ hidden_states,
1097
+ encoder_hidden_states=None,
1098
+ attention_mask=None,
1099
+ external_kv=None,
1100
+ temb=None,
1101
+ ):
1102
+ assert temb is not None, "Timestep embedding is needed for a time-aware attention processor."
1103
+
1104
+ residual = hidden_states
1105
+
1106
+ if attn.spatial_norm is not None:
1107
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1108
+
1109
+ input_ndim = hidden_states.ndim
1110
+
1111
+ if input_ndim == 4:
1112
+ batch_size, channel, height, width = hidden_states.shape
1113
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1114
+
1115
+ if not isinstance(encoder_hidden_states, tuple):
1116
+ # get encoder_hidden_states, ip_hidden_states
1117
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
1118
+ encoder_hidden_states, ip_hidden_states = (
1119
+ encoder_hidden_states[:, :end_pos, :],
1120
+ encoder_hidden_states[:, end_pos:, :],
1121
+ )
1122
+ else:
1123
+ # FIXME: now hard coded to single image prompt.
1124
+ batch_size, _, hid_dim = encoder_hidden_states[0].shape
1125
+ ip_hidden_states = encoder_hidden_states[1][0]
1126
+ encoder_hidden_states = encoder_hidden_states[0]
1127
+ batch_size, sequence_length, _ = (
1128
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1129
+ )
1130
+
1131
+ if attention_mask is not None:
1132
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1133
+ # scaled_dot_product_attention expects attention_mask shape to be
1134
+ # (batch, heads, source_length, target_length)
1135
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1136
+
1137
+ if attn.group_norm is not None:
1138
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1139
+
1140
+ query = attn.to_q(hidden_states)
1141
+
1142
+ if encoder_hidden_states is None:
1143
+ encoder_hidden_states = hidden_states
1144
+ else:
1145
+ if attn.norm_cross:
1146
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1147
+
1148
+ key = attn.to_k(encoder_hidden_states)
1149
+ value = attn.to_v(encoder_hidden_states)
1150
+
1151
+ if external_kv:
1152
+ key = torch.cat([key, external_kv.k], axis=1)
1153
+ value = torch.cat([value, external_kv.v], axis=1)
1154
+
1155
+ inner_dim = key.shape[-1]
1156
+ head_dim = inner_dim // attn.heads
1157
+
1158
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1159
+
1160
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1161
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1162
+
1163
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1164
+ # TODO: add support for attn.scale when we move to Torch 2.1
1165
+ hidden_states = F.scaled_dot_product_attention(
1166
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1167
+ )
1168
+
1169
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1170
+ hidden_states = hidden_states.to(query.dtype)
1171
+
1172
+ # for ip-adapter
1173
+ ip_key = self.to_k_ip(ip_hidden_states)
1174
+ ip_value = self.to_v_ip(ip_hidden_states)
1175
+
1176
+ # time-dependent adaLN
1177
+ ip_key = self.ln_k_ip(ip_key, temb)
1178
+ ip_value = self.ln_v_ip(ip_value, temb)
1179
+
1180
+ ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1181
+ ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1182
+
1183
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1184
+ # TODO: add support for attn.scale when we move to Torch 2.1
1185
+ ip_hidden_states = F.scaled_dot_product_attention(
1186
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
1187
+ )
1188
+
1189
+ ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1190
+ ip_hidden_states = ip_hidden_states.to(query.dtype)
1191
+
1192
+ hidden_states = hidden_states + self.scale * ip_hidden_states
1193
+
1194
+ # linear proj
1195
+ hidden_states = attn.to_out[0](hidden_states)
1196
+ # dropout
1197
+ hidden_states = attn.to_out[1](hidden_states)
1198
+
1199
+ if input_ndim == 4:
1200
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1201
+
1202
+ if attn.residual_connection:
1203
+ hidden_states = hidden_states + residual
1204
+
1205
+ hidden_states = hidden_states / attn.rescale_output_factor
1206
+
1207
+ return hidden_states
1208
+
1209
+
1210
+ ## for controlnet
1211
+ class CNAttnProcessor:
1212
+ r"""
1213
+ Default processor for performing attention-related computations.
1214
+ """
1215
+
1216
+ def __init__(self, num_tokens=4):
1217
+ self.num_tokens = num_tokens
1218
+
1219
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
1220
+ residual = hidden_states
1221
+
1222
+ if attn.spatial_norm is not None:
1223
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1224
+
1225
+ input_ndim = hidden_states.ndim
1226
+
1227
+ if input_ndim == 4:
1228
+ batch_size, channel, height, width = hidden_states.shape
1229
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1230
+
1231
+ batch_size, sequence_length, _ = (
1232
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1233
+ )
1234
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1235
+
1236
+ if attn.group_norm is not None:
1237
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1238
+
1239
+ query = attn.to_q(hidden_states)
1240
+
1241
+ if encoder_hidden_states is None:
1242
+ encoder_hidden_states = hidden_states
1243
+ else:
1244
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
1245
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
1246
+ if attn.norm_cross:
1247
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1248
+
1249
+ key = attn.to_k(encoder_hidden_states)
1250
+ value = attn.to_v(encoder_hidden_states)
1251
+
1252
+ query = attn.head_to_batch_dim(query)
1253
+ key = attn.head_to_batch_dim(key)
1254
+ value = attn.head_to_batch_dim(value)
1255
+
1256
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1257
+ hidden_states = torch.bmm(attention_probs, value)
1258
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1259
+
1260
+ # linear proj
1261
+ hidden_states = attn.to_out[0](hidden_states)
1262
+ # dropout
1263
+ hidden_states = attn.to_out[1](hidden_states)
1264
+
1265
+ if input_ndim == 4:
1266
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1267
+
1268
+ if attn.residual_connection:
1269
+ hidden_states = hidden_states + residual
1270
+
1271
+ hidden_states = hidden_states / attn.rescale_output_factor
1272
+
1273
+ return hidden_states
1274
+
1275
+
1276
+ class CNAttnProcessor2_0:
1277
+ r"""
1278
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1279
+ """
1280
+
1281
+ def __init__(self, num_tokens=4):
1282
+ if not hasattr(F, "scaled_dot_product_attention"):
1283
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1284
+ self.num_tokens = num_tokens
1285
+
1286
+ def __call__(
1287
+ self,
1288
+ attn,
1289
+ hidden_states,
1290
+ encoder_hidden_states=None,
1291
+ attention_mask=None,
1292
+ temb=None,
1293
+ ):
1294
+ residual = hidden_states
1295
+
1296
+ if attn.spatial_norm is not None:
1297
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1298
+
1299
+ input_ndim = hidden_states.ndim
1300
+
1301
+ if input_ndim == 4:
1302
+ batch_size, channel, height, width = hidden_states.shape
1303
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1304
+
1305
+ batch_size, sequence_length, _ = (
1306
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1307
+ )
1308
+
1309
+ if attention_mask is not None:
1310
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1311
+ # scaled_dot_product_attention expects attention_mask shape to be
1312
+ # (batch, heads, source_length, target_length)
1313
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1314
+
1315
+ if attn.group_norm is not None:
1316
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1317
+
1318
+ query = attn.to_q(hidden_states)
1319
+
1320
+ if encoder_hidden_states is None:
1321
+ encoder_hidden_states = hidden_states
1322
+ else:
1323
+ end_pos = encoder_hidden_states.shape[1] - self.num_tokens
1324
+ encoder_hidden_states = encoder_hidden_states[:, :end_pos] # only use text
1325
+ if attn.norm_cross:
1326
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1327
+
1328
+ key = attn.to_k(encoder_hidden_states)
1329
+ value = attn.to_v(encoder_hidden_states)
1330
+
1331
+ inner_dim = key.shape[-1]
1332
+ head_dim = inner_dim // attn.heads
1333
+
1334
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1335
+
1336
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1337
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1338
+
1339
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1340
+ # TODO: add support for attn.scale when we move to Torch 2.1
1341
+ hidden_states = F.scaled_dot_product_attention(
1342
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1343
+ )
1344
+
1345
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1346
+ hidden_states = hidden_states.to(query.dtype)
1347
+
1348
+ # linear proj
1349
+ hidden_states = attn.to_out[0](hidden_states)
1350
+ # dropout
1351
+ hidden_states = attn.to_out[1](hidden_states)
1352
+
1353
+ if input_ndim == 4:
1354
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1355
+
1356
+ if attn.residual_connection:
1357
+ hidden_states = hidden_states + residual
1358
+
1359
+ hidden_states = hidden_states / attn.rescale_output_factor
1360
+
1361
+ return hidden_states
1362
+
1363
+
1364
+ def init_attn_proc(unet, ip_adapter_tokens=16, use_lcm=False, use_adaln=True, use_external_kv=False):
1365
+ attn_procs = {}
1366
+ unet_sd = unet.state_dict()
1367
+ for name in unet.attn_processors.keys():
1368
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
1369
+ if name.startswith("mid_block"):
1370
+ hidden_size = unet.config.block_out_channels[-1]
1371
+ elif name.startswith("up_blocks"):
1372
+ block_id = int(name[len("up_blocks.")])
1373
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
1374
+ elif name.startswith("down_blocks"):
1375
+ block_id = int(name[len("down_blocks.")])
1376
+ hidden_size = unet.config.block_out_channels[block_id]
1377
+ if cross_attention_dim is None:
1378
+ if use_external_kv:
1379
+ attn_procs[name] = AdditiveKV_AttnProcessor2_0(
1380
+ hidden_size=hidden_size,
1381
+ cross_attention_dim=cross_attention_dim,
1382
+ time_embedding_dim=1280,
1383
+ ) if hasattr(F, "scaled_dot_product_attention") else AdditiveKV_AttnProcessor()
1384
+ else:
1385
+ attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
1386
+ else:
1387
+ if use_adaln:
1388
+ layer_name = name.split(".processor")[0]
1389
+ if use_lcm:
1390
+ weights = {
1391
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.base_layer.weight"],
1392
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.base_layer.weight"],
1393
+ }
1394
+ else:
1395
+ weights = {
1396
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
1397
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
1398
+ }
1399
+ attn_procs[name] = TA_IPAttnProcessor2_0(
1400
+ hidden_size=hidden_size,
1401
+ cross_attention_dim=cross_attention_dim,
1402
+ num_tokens=ip_adapter_tokens,
1403
+ time_embedding_dim=1280,
1404
+ ) if hasattr(F, "scaled_dot_product_attention") else \
1405
+ TA_IPAttnProcessor(
1406
+ hidden_size=hidden_size,
1407
+ cross_attention_dim=cross_attention_dim,
1408
+ num_tokens=ip_adapter_tokens,
1409
+ time_embedding_dim=1280,
1410
+ )
1411
+ attn_procs[name].load_state_dict(weights, strict=False)
1412
+ else:
1413
+ attn_procs[name] = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
1414
+
1415
+ return attn_procs
1416
+
1417
+
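A hedged usage sketch for `init_attn_proc`; the model id and token count below are assumptions for illustration, not a prescribed configuration:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
attn_procs = init_attn_proc(unet, ip_adapter_tokens=64, use_lcm=False, use_adaln=True)
unet.set_attn_processor(attn_procs)
# With use_adaln=True, cross-attention layers now carry TA_IPAttnProcessor2_0 modules
# whose to_k_ip / to_v_ip weights are initialized from the UNet's own to_k / to_v.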
1418
+ def init_aggregator_attn_proc(unet, use_adaln=False, split_attn=False):
1419
+ attn_procs = {}
1420
+ unet_sd = unet.state_dict()
1421
+ for name in unet.attn_processors.keys():
1422
+ # get layer name and hidden dim
1423
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
1424
+ if name.startswith("mid_block"):
1425
+ hidden_size = unet.config.block_out_channels[-1]
1426
+ elif name.startswith("up_blocks"):
1427
+ block_id = int(name[len("up_blocks.")])
1428
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
1429
+ elif name.startswith("down_blocks"):
1430
+ block_id = int(name[len("down_blocks.")])
1431
+ hidden_size = unet.config.block_out_channels[block_id]
1432
+ # init attn proc
1433
+ if split_attn:
1434
+ # layer_name = name.split(".processor")[0]
1435
+ # weights = {
1436
+ # "to_q_ref.weight": unet_sd[layer_name + ".to_q.weight"],
1437
+ # "to_k_ref.weight": unet_sd[layer_name + ".to_k.weight"],
1438
+ # "to_v_ref.weight": unet_sd[layer_name + ".to_v.weight"],
1439
+ # }
1440
+ attn_procs[name] = (
1441
+ sep_split_AttnProcessor2_0(
1442
+ hidden_size=hidden_size,
1443
+ cross_attention_dim=hidden_size,
1444
+ time_embedding_dim=1280,
1445
+ )
1446
+ if use_adaln
1447
+ else split_AttnProcessor2_0(
1448
+ hidden_size=hidden_size,
1449
+ cross_attention_dim=cross_attention_dim,
1450
+ time_embedding_dim=1280,
1451
+ )
1452
+ )
1453
+ # attn_procs[name].load_state_dict(weights, strict=False)
1454
+ else:
1455
+ attn_procs[name] = (
1456
+ AttnProcessor2_0(
1457
+ hidden_size=hidden_size,
1458
+ cross_attention_dim=hidden_size,
1459
+ )
1460
+ if hasattr(F, "scaled_dot_product_attention")
1461
+ else AttnProcessor(
1462
+ hidden_size=hidden_size,
1463
+ cross_attention_dim=hidden_size,
1464
+ )
1465
+ )
1466
+
1467
+ return attn_procs
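And a companion sketch for the aggregator variant, assuming the aggregator module (defined in `module/aggregator.py`) exposes the same `attn_processors` / `set_attn_processor` API as the diffusers UNet:

# With split_attn=True and use_adaln=True every layer gets a sep_split_AttnProcessor2_0,
# so the main and reference streams are attended separately with time-dependent adaLN
# applied to the reference keys/values.
aggregator_procs = init_aggregator_attn_proc(aggregator, use_adaln=True, split_attn=True)
aggregator.set_attn_processor(aggregator_procs)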
module/ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,236 @@
1
+ import os
2
+ import torch
3
+ from typing import List
4
+ from collections import namedtuple, OrderedDict
5
+
6
+ def is_torch2_available():
7
+ return hasattr(torch.nn.functional, "scaled_dot_product_attention")
8
+
9
+ if is_torch2_available():
10
+ from .attention_processor import (
11
+ AttnProcessor2_0 as AttnProcessor,
12
+ )
13
+ from .attention_processor import (
14
+ CNAttnProcessor2_0 as CNAttnProcessor,
15
+ )
16
+ from .attention_processor import (
17
+ IPAttnProcessor2_0 as IPAttnProcessor,
18
+ )
19
+ from .attention_processor import (
20
+ TA_IPAttnProcessor2_0 as TA_IPAttnProcessor,
21
+ )
22
+ else:
23
+ from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor, TA_IPAttnProcessor
24
+
25
+
26
+ class ImageProjModel(torch.nn.Module):
27
+ """Projection Model"""
28
+
29
+ def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
30
+ super().__init__()
31
+
32
+ self.cross_attention_dim = cross_attention_dim
33
+ self.clip_extra_context_tokens = clip_extra_context_tokens
34
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
35
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
36
+
37
+ def forward(self, image_embeds):
38
+ embeds = image_embeds
39
+ clip_extra_context_tokens = self.proj(embeds).reshape(
40
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
41
+ )
42
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
43
+ return clip_extra_context_tokens
44
+
45
+
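A quick shape sketch for `ImageProjModel`; the dimensions are typical SDXL/CLIP values, assumed here purely for illustration:

import torch

proj_model = ImageProjModel(cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4)
pooled_clip_embeds = torch.randn(2, 1280)       # one pooled CLIP embedding per image
extra_tokens = proj_model(pooled_clip_embeds)   # -> (2, 4, 2048), ready to concatenate with text tokens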
46
+ class MLPProjModel(torch.nn.Module):
47
+ """SD model with image prompt"""
48
+ def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280):
49
+ super().__init__()
50
+
51
+ self.proj = torch.nn.Sequential(
52
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
53
+ torch.nn.GELU(),
54
+ torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
55
+ torch.nn.LayerNorm(cross_attention_dim)
56
+ )
57
+
58
+ def forward(self, image_embeds):
59
+ clip_extra_context_tokens = self.proj(image_embeds)
60
+ return clip_extra_context_tokens
61
+
62
+
63
+ class MultiIPAdapterImageProjection(torch.nn.Module):
64
+ def __init__(self, IPAdapterImageProjectionLayers):
65
+ super().__init__()
66
+ self.image_projection_layers = torch.nn.ModuleList(IPAdapterImageProjectionLayers)
67
+
68
+ def forward(self, image_embeds: List[torch.FloatTensor]):
69
+ projected_image_embeds = []
70
+
71
+ # currently, we accept `image_embeds` as
72
+ # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim]
73
+ # 2. list of `n` tensors where `n` is the number of ip-adapters, each tensor can have shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim]
74
+ if not isinstance(image_embeds, list):
75
+ image_embeds = [image_embeds.unsqueeze(1)]
76
+
77
+ if len(image_embeds) != len(self.image_projection_layers):
78
+ raise ValueError(
79
+ f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}"
80
+ )
81
+
82
+ for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers):
83
+ batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
84
+ image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
85
+ image_embed = image_projection_layer(image_embed)
86
+ # image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
87
+
88
+ projected_image_embeds.append(image_embed)
89
+
90
+ return projected_image_embeds
91
+
92
+
93
+ class IPAdapter(torch.nn.Module):
94
+ """IP-Adapter"""
95
+ def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
96
+ super().__init__()
97
+ self.unet = unet
98
+ self.image_proj = image_proj_model
99
+ self.ip_adapter = adapter_modules
100
+
101
+ if ckpt_path is not None:
102
+ self.load_from_checkpoint(ckpt_path)
103
+
104
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
105
+ ip_tokens = self.image_proj(image_embeds)
106
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
107
+ # Predict the noise residual
108
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
109
+ return noise_pred
110
+
111
+ def load_from_checkpoint(self, ckpt_path: str):
112
+ # Calculate original checksums
113
+ orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
114
+ orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
115
+
116
+ state_dict = torch.load(ckpt_path, map_location="cpu")
117
+ keys = list(state_dict.keys())
118
+ if keys != ["image_proj", "ip_adapter"]:
119
+ state_dict = revise_state_dict(state_dict)
120
+
121
+ # Load state dict for image_proj_model and adapter_modules
122
+ self.image_proj.load_state_dict(state_dict["image_proj"], strict=True)
123
+ self.ip_adapter.load_state_dict(state_dict["ip_adapter"], strict=True)
124
+
125
+ # Calculate new checksums
126
+ new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
127
+ new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
128
+
129
+ # Verify if the weights have changed
130
+ assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
131
+ assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
132
+
133
+
134
+ class IPAdapterPlus(torch.nn.Module):
135
+ """IP-Adapter"""
136
+ def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
137
+ super().__init__()
138
+ self.unet = unet
139
+ self.image_proj = image_proj_model
140
+ self.ip_adapter = adapter_modules
141
+
142
+ if ckpt_path is not None:
143
+ self.load_from_checkpoint(ckpt_path)
144
+
145
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
146
+ ip_tokens = self.image_proj(image_embeds)
147
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
148
+ # Predict the noise residual
149
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
150
+ return noise_pred
151
+
152
+ def load_from_checkpoint(self, ckpt_path: str):
153
+ # Calculate original checksums
154
+ orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
155
+ orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
156
+ org_unet_sum = []
157
+ for attn_name, attn_proc in self.unet.attn_processors.items():
158
+ if isinstance(attn_proc, (TA_IPAttnProcessor, IPAttnProcessor)):
159
+ org_unet_sum.append(torch.sum(torch.stack([torch.sum(p) for p in attn_proc.parameters()])))
160
+ org_unet_sum = torch.sum(torch.stack(org_unet_sum))
161
+
162
+ state_dict = torch.load(ckpt_path, map_location="cpu")
163
+ keys = list(state_dict.keys())
164
+ if keys != ["image_proj", "ip_adapter"]:
165
+ state_dict = revise_state_dict(state_dict)
166
+
167
+ # Check if 'latents' exists in both the saved state_dict and the current model's state_dict
168
+ strict_load_image_proj_model = True
169
+ if "latents" in state_dict["image_proj"] and "latents" in self.image_proj.state_dict():
170
+ # Check if the shapes are mismatched
171
+ if state_dict["image_proj"]["latents"].shape != self.image_proj.state_dict()["latents"].shape:
172
+ print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
173
+ print("Removing 'latents' from checkpoint and loading the rest of the weights.")
174
+ del state_dict["image_proj"]["latents"]
175
+ strict_load_image_proj_model = False
176
+
177
+ # Load state dict for image_proj_model and adapter_modules
178
+ self.image_proj.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
179
+ missing_key, unexpected_key = self.ip_adapter.load_state_dict(state_dict["ip_adapter"], strict=False)
180
+ if len(missing_key) > 0:
181
+ for ms in missing_key:
182
+ if "ln" not in ms:
183
+ raise ValueError(f"Missing keys in adapter_modules: {missing_key}")
184
+ if len(unexpected_key) > 0:
185
+ raise ValueError(f"Unexpected key in adapter_modules: {len(unexpected_key)}")
186
+
187
+ # Calculate new checksums
188
+ new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj.parameters()]))
189
+ new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.ip_adapter.parameters()]))
190
+
191
+ # Verify if the weights loaded to unet
192
+ unet_sum = []
193
+ for attn_name, attn_proc in self.unet.attn_processors.items():
194
+ if isinstance(attn_proc, (TA_IPAttnProcessor, IPAttnProcessor)):
195
+ unet_sum.append(torch.sum(torch.stack([torch.sum(p) for p in attn_proc.parameters()])))
196
+ unet_sum = torch.sum(torch.stack(unet_sum))
197
+
198
+ assert org_unet_sum != unet_sum, "Weights of adapter_modules in unet did not change!"
199
+ assert (unet_sum - new_adapter_sum < 1e-4), "Weights of adapter_modules did not load to unet!"
200
+
201
+ # Verify if the weights have changed
202
+ assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
203
+ assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
204
+
205
+
206
+ class IPAdapterXL(IPAdapter):
207
+ """SDXL"""
208
+
209
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
210
+ ip_tokens = self.image_proj(image_embeds)
211
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
212
+ # Predict the noise residual
213
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
214
+ return noise_pred
215
+
216
+
217
+ class IPAdapterPlusXL(IPAdapterPlus):
218
+ """IP-Adapter with fine-grained features"""
219
+
220
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
221
+ ip_tokens = self.image_proj(image_embeds)
222
+ encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
223
+ # Predict the noise residual
224
+ noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs).sample
225
+ return noise_pred
226
+
227
+
228
+ class IPAdapterFull(IPAdapterPlus):
229
+ """IP-Adapter with full features"""
230
+
231
+ def init_proj(self):
232
+ image_proj_model = MLPProjModel(
233
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
234
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
235
+ ).to(self.device, dtype=torch.float16)
236
+ return image_proj_model
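The `load_from_checkpoint` methods above verify loading by comparing parameter-sum "checksums" before and after; a self-contained toy version of that trick, with a plain Linear layer standing in for the adapter modules:

import torch

module = torch.nn.Linear(4, 4)
before = torch.sum(torch.stack([p.sum() for p in module.parameters()]))
module.load_state_dict({"weight": torch.ones(4, 4), "bias": torch.zeros(4)})
after = torch.sum(torch.stack([p.sum() for p in module.parameters()]))
assert before != after, "weights did not change"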
module/ip_adapter/resampler.py ADDED
@@ -0,0 +1,158 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # reshape keeps (bs, n_heads, length, dim_per_head) but makes the tensor contiguous
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latent (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
+ out = weight @ v
75
+
76
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
+
78
+ return self.to_out(out)
79
+
80
+
81
+ class Resampler(nn.Module):
82
+ def __init__(
83
+ self,
84
+ dim=1280,
85
+ depth=4,
86
+ dim_head=64,
87
+ heads=20,
88
+ num_queries=64,
89
+ embedding_dim=768,
90
+ output_dim=1024,
91
+ ff_mult=4,
92
+ max_seq_len: int = 257, # CLIP tokens + CLS token
93
+ apply_pos_emb: bool = False,
94
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
95
+ ):
96
+ super().__init__()
97
+ self.pos_emb = nn.Embedding(max_seq_len, embedding_dim) if apply_pos_emb else None
98
+
99
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
100
+
101
+ self.proj_in = nn.Linear(embedding_dim, dim)
102
+
103
+ self.proj_out = nn.Linear(dim, output_dim)
104
+ self.norm_out = nn.LayerNorm(output_dim)
105
+
106
+ self.to_latents_from_mean_pooled_seq = (
107
+ nn.Sequential(
108
+ nn.LayerNorm(dim),
109
+ nn.Linear(dim, dim * num_latents_mean_pooled),
110
+ Rearrange("b (n d) -> b n d", n=num_latents_mean_pooled),
111
+ )
112
+ if num_latents_mean_pooled > 0
113
+ else None
114
+ )
115
+
116
+ self.layers = nn.ModuleList([])
117
+ for _ in range(depth):
118
+ self.layers.append(
119
+ nn.ModuleList(
120
+ [
121
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
122
+ FeedForward(dim=dim, mult=ff_mult),
123
+ ]
124
+ )
125
+ )
126
+
127
+ def forward(self, x):
128
+ if self.pos_emb is not None:
129
+ n, device = x.shape[1], x.device
130
+ pos_emb = self.pos_emb(torch.arange(n, device=device))
131
+ x = x + pos_emb
132
+
133
+ latents = self.latents.repeat(x.size(0), 1, 1)
134
+
135
+ x = self.proj_in(x)
136
+
137
+ if self.to_latents_from_mean_pooled_seq:
138
+ meanpooled_seq = masked_mean(x, dim=1, mask=torch.ones(x.shape[:2], device=x.device, dtype=torch.bool))
139
+ meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
140
+ latents = torch.cat((meanpooled_latents, latents), dim=-2)
141
+
142
+ for attn, ff in self.layers:
143
+ latents = attn(x, latents) + latents
144
+ latents = ff(latents) + latents
145
+
146
+ latents = self.proj_out(latents)
147
+ return self.norm_out(latents)
148
+
149
+
150
+ def masked_mean(t, *, dim, mask=None):
151
+ if mask is None:
152
+ return t.mean(dim=dim)
153
+
154
+ denom = mask.sum(dim=dim, keepdim=True)
155
+ mask = rearrange(mask, "b n -> b n 1")
156
+ masked_t = t.masked_fill(~mask, 0.0)
157
+
158
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
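A shape sketch for `Resampler`; the dimensions are assumptions, and `embedding_dim` must match the patch-token width of whichever image encoder is used:

import torch

resampler = Resampler(embedding_dim=1024, output_dim=2048, num_queries=64)
patch_tokens = torch.randn(2, 257, 1024)   # e.g. 256 patch tokens + 1 CLS token per image
ip_tokens = resampler(patch_tokens)        # -> (2, 64, 2048) image-prompt tokens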
module/ip_adapter/utils.py ADDED
@@ -0,0 +1,248 @@
1
+ import torch
2
+ from collections import namedtuple, OrderedDict
3
+ from safetensors import safe_open
4
+ from .attention_processor import init_attn_proc
5
+ from .ip_adapter import MultiIPAdapterImageProjection
6
+ from .resampler import Resampler
7
+ from transformers import (
8
+ AutoModel, AutoImageProcessor,
9
+ CLIPVisionModelWithProjection, CLIPImageProcessor)
10
+
11
+
12
+ def init_adapter_in_unet(
13
+ unet,
14
+ image_proj_model=None,
15
+ pretrained_model_path_or_dict=None,
16
+ adapter_tokens=64,
17
+ embedding_dim=None,
18
+ use_lcm=False,
19
+ use_adaln=True,
20
+ ):
21
+ device = unet.device
22
+ dtype = unet.dtype
23
+ if image_proj_model is None:
24
+ assert embedding_dim is not None, "embedding_dim must be provided if image_proj_model is None."
25
+ image_proj_model = Resampler(
26
+ embedding_dim=embedding_dim,
27
+ output_dim=unet.config.cross_attention_dim,
28
+ num_queries=adapter_tokens,
29
+ )
30
+ if pretrained_model_path_or_dict is not None:
31
+ if not isinstance(pretrained_model_path_or_dict, dict):
32
+ if pretrained_model_path_or_dict.endswith(".safetensors"):
33
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
34
+ with safe_open(pretrained_model_path_or_dict, framework="pt", device=unet.device) as f:
35
+ for key in f.keys():
36
+ if key.startswith("image_proj."):
37
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
38
+ elif key.startswith("ip_adapter."):
39
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
40
+ else:
41
+ state_dict = torch.load(pretrained_model_path_or_dict, map_location=unet.device)
42
+ else:
43
+ state_dict = pretrained_model_path_or_dict
44
+ keys = list(state_dict.keys())
45
+ if "image_proj" not in keys and "ip_adapter" not in keys:
46
+ state_dict = revise_state_dict(state_dict)
47
+
48
+ # Create IP cross-attention processors in the unet.
49
+ attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
50
+ unet.set_attn_processor(attn_procs)
51
+
52
+ # Load the pretrained adapter weights if provided.
53
+ if pretrained_model_path_or_dict is not None:
54
+ if "ip_adapter" in state_dict.keys():
55
+ adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
56
+ missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
57
+ for mk in missing:
58
+ if "ln" not in mk:
59
+ raise ValueError(f"Missing keys in adapter_modules: {missing}")
60
+ if "image_proj" in state_dict.keys():
61
+ image_proj_model.load_state_dict(state_dict["image_proj"])
62
+
63
+ # Load image projectors into iterable ModuleList.
64
+ image_projection_layers = []
65
+ image_projection_layers.append(image_proj_model)
66
+ unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
67
+
68
+ # Adjust unet config to handle additional ip hidden states.
69
+ unet.config.encoder_hid_dim_type = "ip_image_proj"
70
+ unet.to(dtype=dtype, device=device)
71
+
72
+
73
+ def load_adapter_to_pipe(
74
+ pipe,
75
+ pretrained_model_path_or_dict,
76
+ image_encoder_or_path=None,
77
+ feature_extractor_or_path=None,
78
+ use_clip_encoder=False,
79
+ adapter_tokens=64,
80
+ use_lcm=False,
81
+ use_adaln=True,
82
+ ):
83
+
84
+ if not isinstance(pretrained_model_path_or_dict, dict):
85
+ if pretrained_model_path_or_dict.endswith(".safetensors"):
86
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
87
+ with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
88
+ for key in f.keys():
89
+ if key.startswith("image_proj."):
90
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
91
+ elif key.startswith("ip_adapter."):
92
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
93
+ else:
94
+ state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
95
+ else:
96
+ state_dict = pretrained_model_path_or_dict
97
+ keys = list(state_dict.keys())
98
+ if "image_proj" not in keys and "ip_adapter" not in keys:
99
+ state_dict = revise_state_dict(state_dict)
100
+
101
+ # load CLIP image encoder here if it has not been registered to the pipeline yet
102
+ if image_encoder_or_path is not None:
103
+ if isinstance(image_encoder_or_path, str):
104
+ feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path
105
+
106
+ image_encoder_or_path = (
107
+ CLIPVisionModelWithProjection.from_pretrained(
108
+ image_encoder_or_path
109
+ ) if use_clip_encoder else
110
+ AutoModel.from_pretrained(image_encoder_or_path)
111
+ )
112
+
113
+ if feature_extractor_or_path is not None:
114
+ if isinstance(feature_extractor_or_path, str):
115
+ feature_extractor_or_path = (
116
+ CLIPImageProcessor() if use_clip_encoder else
117
+ AutoImageProcessor.from_pretrained(feature_extractor_or_path)
118
+ )
119
+
120
+ # create image encoder if it has not been registered to the pipeline yet
121
+ if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
122
+ image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
123
+ pipe.register_modules(image_encoder=image_encoder)
124
+ else:
125
+ image_encoder = pipe.image_encoder
126
+
127
+ # create feature extractor if it has not been registered to the pipeline yet
128
+ if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
129
+ feature_extractor = feature_extractor_or_path
130
+ pipe.register_modules(feature_extractor=feature_extractor)
131
+ else:
132
+ feature_extractor = pipe.feature_extractor
133
+
134
+ # load adapter into unet
135
+ unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
136
+ attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
137
+ unet.set_attn_processor(attn_procs)
138
+ image_proj_model = Resampler(
139
+ embedding_dim=image_encoder.config.hidden_size,
140
+ output_dim=unet.config.cross_attention_dim,
141
+ num_queries=adapter_tokens,
142
+ )
143
+
144
+ # Load the pretrained adapter weights if provided.
145
+ if "ip_adapter" in state_dict.keys():
146
+ adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
147
+ missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
148
+ for mk in missing:
149
+ if "ln" not in mk:
150
+ raise ValueError(f"Missing keys in adapter_modules: {missing}")
151
+ if "image_proj" in state_dict.keys():
152
+ image_proj_model.load_state_dict(state_dict["image_proj"])
153
+
154
+ # convert IP-Adapter Image Projection layers to diffusers
155
+ image_projection_layers = []
156
+ image_projection_layers.append(image_proj_model)
157
+ unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
158
+
159
+ # Adjust unet config to handle additional ip hidden states.
160
+ unet.config.encoder_hid_dim_type = "ip_image_proj"
161
+ unet.to(dtype=pipe.dtype, device=pipe.device)
162
+
163
+
164
+ def revise_state_dict(old_state_dict_or_path, map_location="cpu"):
165
+ new_state_dict = OrderedDict()
166
+ new_state_dict["image_proj"] = OrderedDict()
167
+ new_state_dict["ip_adapter"] = OrderedDict()
168
+ if isinstance(old_state_dict_or_path, str):
169
+ old_state_dict = torch.load(old_state_dict_or_path, map_location=map_location)
170
+ else:
171
+ old_state_dict = old_state_dict_or_path
172
+ for name, weight in old_state_dict.items():
173
+ if name.startswith("image_proj_model."):
174
+ new_state_dict["image_proj"][name[len("image_proj_model."):]] = weight
175
+ elif name.startswith("adapter_modules."):
176
+ new_state_dict["ip_adapter"][name[len("adapter_modules."):]] = weight
177
+ return new_state_dict
178
+
179
+
180
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
181
+ def encode_image(image_encoder, feature_extractor, image, device, num_images_per_prompt, output_hidden_states=None):
182
+ dtype = next(image_encoder.parameters()).dtype
183
+
184
+ if not isinstance(image, torch.Tensor):
185
+ image = feature_extractor(image, return_tensors="pt").pixel_values
186
+
187
+ image = image.to(device=device, dtype=dtype)
188
+ if output_hidden_states:
189
+ image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2]
190
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
191
+ return image_enc_hidden_states
192
+ else:
193
+ if isinstance(image_encoder, CLIPVisionModelWithProjection):
194
+ # CLIP image encoder.
195
+ image_embeds = image_encoder(image).image_embeds
196
+ else:
197
+ # DINO image encoder.
198
+ image_embeds = image_encoder(image).last_hidden_state
199
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
200
+ return image_embeds
201
+
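# Usage sketch (illustrative, not part of this diff): encoding a reference image with
# a CLIP vision tower. The checkpoint name is an assumption and is downloaded from the
# Hub; any CLIPVisionModelWithProjection works here.
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

clip = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
processor = CLIPImageProcessor()
image = Image.new("RGB", (224, 224))  # stand-in for a real reference image
# Pooled projection embeddings (one row per generated image).
embeds = encode_image(clip, processor, image, device="cpu", num_images_per_prompt=1)
# Penultimate hidden states, the variant consumed by Resampler-style projectors.
hidden = encode_image(clip, processor, image, "cpu", 1, output_hidden_states=True)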
202
+
203
+ def prepare_training_image_embeds(
204
+ image_encoder, feature_extractor,
205
+ ip_adapter_image, ip_adapter_image_embeds,
206
+ device, drop_rate, output_hidden_state, idx_to_replace=None, num_images_per_prompt=1, do_classifier_free_guidance=False
207
+ ):
208
+ if ip_adapter_image_embeds is None:
209
+ if not isinstance(ip_adapter_image, list):
210
+ ip_adapter_image = [ip_adapter_image]
211
+
212
+ # if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):
213
+ # raise ValueError(
214
+ # f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
215
+ # )
216
+
217
+ image_embeds = []
218
+ for single_ip_adapter_image in ip_adapter_image:
219
+ if idx_to_replace is None:
220
+ idx_to_replace = torch.rand(len(single_ip_adapter_image)) < drop_rate
221
+ zero_ip_adapter_image = torch.zeros_like(single_ip_adapter_image)
222
+ single_ip_adapter_image[idx_to_replace] = zero_ip_adapter_image[idx_to_replace]
223
+ single_image_embeds = encode_image(
224
+ image_encoder, feature_extractor, single_ip_adapter_image, device, 1, output_hidden_state
225
+ )
226
+ single_image_embeds = torch.stack([single_image_embeds], dim=1) # FIXME
227
+
228
+ image_embeds.append(single_image_embeds)
229
+ else:
230
+ repeat_dims = [1]
231
+ image_embeds = []
232
+ for single_image_embeds in ip_adapter_image_embeds:
233
+ if do_classifier_free_guidance:
234
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
235
+ single_image_embeds = single_image_embeds.repeat(
236
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
237
+ )
238
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
239
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
240
+ )
241
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
242
+ else:
243
+ single_image_embeds = single_image_embeds.repeat(
244
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
245
+ )
246
+ image_embeds.append(single_image_embeds)
247
+
248
+ return image_embeds
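# Training-time sketch (illustrative, not part of this diff): drop roughly 10% of the
# reference images so the adapter also learns an unconditional branch. Reuses `clip`
# and `processor` from the encode_image sketch above; the random batch stands in for
# CLIP-preprocessed reference images.
import torch

ref_images = torch.randn(4, 3, 224, 224)
image_embeds = prepare_training_image_embeds(
    clip, processor,
    ip_adapter_image=ref_images, ip_adapter_image_embeds=None,
    device="cpu", drop_rate=0.1, output_hidden_state=True,
)
# -> a list with one tensor of shape (batch, 1, num_tokens, hidden_size)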
module/min_sdxl.py ADDED
@@ -0,0 +1,915 @@
1
+ # Modified from minSDXL by Simo Ryu:
2
+ # https://github.com/cloneofsimo/minSDXL ,
3
+ # which is in turn modified from the original code of:
4
+ # https://github.com/huggingface/diffusers
5
+ # and is therefore distributed under the Apache 2.0 license
6
+
7
+ from typing import Optional, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import math
13
+ import inspect
14
+
15
+ from collections import namedtuple
16
+
17
+ from torch.fft import fftn, fftshift, ifftn, ifftshift
18
+
19
+ from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
20
+
21
+ # Implementation of FreeU for minsdxl
22
+
23
+ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
24
+ """Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).
25
+
26
+ This version of the method comes from here:
27
+ https://github.com/huggingface/diffusers/pull/5164#issuecomment-1732638706
28
+ """
29
+ x = x_in
30
+ B, C, H, W = x.shape
31
+
32
+ # Non-power of 2 images must be float32
33
+ if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
34
+ x = x.to(dtype=torch.float32)
35
+
36
+ # FFT
37
+ x_freq = fftn(x, dim=(-2, -1))
38
+ x_freq = fftshift(x_freq, dim=(-2, -1))
39
+
40
+ B, C, H, W = x_freq.shape
41
+ mask = torch.ones((B, C, H, W), device=x.device)
42
+
43
+ crow, ccol = H // 2, W // 2
44
+ mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
45
+ x_freq = x_freq * mask
46
+
47
+ # IFFT
48
+ x_freq = ifftshift(x_freq, dim=(-2, -1))
49
+ x_filtered = ifftn(x_freq, dim=(-2, -1)).real
50
+
51
+ return x_filtered.to(dtype=x_in.dtype)
52
+
53
+
54
+ def apply_freeu(
55
+ resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs):
56
+ """Applies the FreeU mechanism as introduced in https:
57
+ //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU.
58
+
59
+ Args:
60
+ resolution_idx (`int`): Integer denoting the UNet block where FreeU is being applied.
61
+ hidden_states (`torch.Tensor`): Inputs to the underlying block.
62
+ res_hidden_states (`torch.Tensor`): Features from the skip block corresponding to the underlying block.
63
+ s1 (`float`): Scaling factor for stage 1 to attenuate the contributions of the skip features.
64
+ s2 (`float`): Scaling factor for stage 2 to attenuate the contributions of the skip features.
65
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
66
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
67
+ """
68
+ if resolution_idx == 0:
69
+ num_half_channels = hidden_states.shape[1] // 2
70
+ hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"]
71
+ res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"])
72
+ if resolution_idx == 1:
73
+ num_half_channels = hidden_states.shape[1] // 2
74
+ hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"]
75
+ res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
76
+
77
+ return hidden_states, res_hidden_states
78
+
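# Illustrative sketch (not part of this diff): the two FreeU knobs. b1/b2 rescale the
# first half of the backbone channels, while s1/s2 damp the low-frequency band of the
# skip features via fourier_filter. The scale values below are arbitrary placeholders.
import torch

hidden = torch.randn(1, 1280, 32, 32)  # backbone features entering an up block
skip = torch.randn(1, 1280, 32, 32)    # matching skip-connection features
hidden, skip = apply_freeu(0, hidden, skip, s1=0.9, s2=0.2, b1=1.3, b2=1.4)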
79
+ # Diffusers-style LoRA to keep everything in the min_sdxl.py file
80
+
81
+ class LoRALinearLayer(nn.Module):
82
+ r"""
83
+ A linear layer that is used with LoRA.
84
+
85
+ Parameters:
86
+ in_features (`int`):
87
+ Number of input features.
88
+ out_features (`int`):
89
+ Number of output features.
90
+ rank (`int`, `optional`, defaults to 4):
91
+ The rank of the LoRA layer.
92
+ network_alpha (`float`, `optional`, defaults to `None`):
93
+ The value of the network alpha used for stable learning and preventing underflow. This value has the same
94
+ meaning as the `--network_alpha` option in the kohya-ss trainer script. See
95
+ https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
96
+ device (`torch.device`, `optional`, defaults to `None`):
97
+ The device to use for the layer's weights.
98
+ dtype (`torch.dtype`, `optional`, defaults to `None`):
99
+ The dtype to use for the layer's weights.
100
+ """
101
+
102
+ def __init__(
103
+ self,
104
+ in_features: int,
105
+ out_features: int,
106
+ rank: int = 4,
107
+ network_alpha: Optional[float] = None,
108
+ device: Optional[Union[torch.device, str]] = None,
109
+ dtype: Optional[torch.dtype] = None,
110
+ ):
111
+ super().__init__()
112
+
113
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
114
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
115
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
116
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
117
+ self.network_alpha = network_alpha
118
+ self.rank = rank
119
+ self.out_features = out_features
120
+ self.in_features = in_features
121
+
122
+ nn.init.normal_(self.down.weight, std=1 / rank)
123
+ nn.init.zeros_(self.up.weight)
124
+
125
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
126
+ orig_dtype = hidden_states.dtype
127
+ dtype = self.down.weight.dtype
128
+
129
+ down_hidden_states = self.down(hidden_states.to(dtype))
130
+ up_hidden_states = self.up(down_hidden_states)
131
+
132
+ if self.network_alpha is not None:
133
+ up_hidden_states *= self.network_alpha / self.rank
134
+
135
+ return up_hidden_states.to(orig_dtype)
136
+
137
+ class LoRACompatibleLinear(nn.Linear):
138
+ """
139
+ A Linear layer that can be used with LoRA.
140
+ """
141
+
142
+ def __init__(self, *args, lora_layer: Optional[LoRALinearLayer] = None, **kwargs):
143
+ super().__init__(*args, **kwargs)
144
+ self.lora_layer = lora_layer
145
+
146
+ def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
147
+ self.lora_layer = lora_layer
148
+
149
+ def _fuse_lora(self, lora_scale: float = 1.0, safe_fusing: bool = False):
150
+ if self.lora_layer is None:
151
+ return
152
+
153
+ dtype, device = self.weight.data.dtype, self.weight.data.device
154
+
155
+ w_orig = self.weight.data.float()
156
+ w_up = self.lora_layer.up.weight.data.float()
157
+ w_down = self.lora_layer.down.weight.data.float()
158
+
159
+ if self.lora_layer.network_alpha is not None:
160
+ w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
161
+
162
+ fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
163
+
164
+ if safe_fusing and torch.isnan(fused_weight).any().item():
165
+ raise ValueError(
166
+ "This LoRA weight seems to be broken. "
167
+ f"Encountered NaN values when trying to fuse LoRA weights for {self}."
168
+ "LoRA weights will not be fused."
169
+ )
170
+
171
+ self.weight.data = fused_weight.to(device=device, dtype=dtype)
172
+
173
+ # we can drop the lora layer now
174
+ self.lora_layer = None
175
+
176
+ # offload the up and down matrices to CPU to not blow the memory
177
+ self.w_up = w_up.cpu()
178
+ self.w_down = w_down.cpu()
179
+ self._lora_scale = lora_scale
180
+
181
+ def _unfuse_lora(self):
182
+ if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
183
+ return
184
+
185
+ fused_weight = self.weight.data
186
+ dtype, device = fused_weight.dtype, fused_weight.device
187
+
188
+ w_up = self.w_up.to(device=device).float()
189
+ w_down = self.w_down.to(device).float()
190
+
191
+ unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
192
+ self.weight.data = unfused_weight.to(device=device, dtype=dtype)
193
+
194
+ self.w_up = None
195
+ self.w_down = None
196
+
197
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
198
+ if self.lora_layer is None:
199
+ out = super().forward(hidden_states)
200
+ return out
201
+ else:
202
+ out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
203
+ return out
204
+
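# Illustrative sketch (not part of this diff): attaching a rank-4 LoRA to a
# LoRACompatibleLinear, scaling its contribution at call time, then fusing it.
import torch

base = LoRACompatibleLinear(320, 320)
base.set_lora_layer(LoRALinearLayer(320, 320, rank=4, network_alpha=4))
x = torch.randn(2, 77, 320)
y = base(x, scale=0.8)            # base output + 0.8 * LoRA delta
base._fuse_lora(lora_scale=0.8)   # bake the scaled delta into the weight matrix
y_fused = base(x)                 # now behaves like a plain nn.Linear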
205
+ class Timesteps(nn.Module):
206
+ def __init__(self, num_channels: int = 320):
207
+ super().__init__()
208
+ self.num_channels = num_channels
209
+
210
+ def forward(self, timesteps):
211
+ half_dim = self.num_channels // 2
212
+ exponent = -math.log(10000) * torch.arange(
213
+ half_dim, dtype=torch.float32, device=timesteps.device
214
+ )
215
+ exponent = exponent / (half_dim - 0.0)
216
+
217
+ emb = torch.exp(exponent)
218
+ emb = timesteps[:, None].float() * emb[None, :]
219
+
220
+ sin_emb = torch.sin(emb)
221
+ cos_emb = torch.cos(emb)
222
+ emb = torch.cat([cos_emb, sin_emb], dim=-1)
223
+
224
+ return emb
225
+
226
+
227
+ class TimestepEmbedding(nn.Module):
228
+ def __init__(self, in_features, out_features):
229
+ super(TimestepEmbedding, self).__init__()
230
+ self.linear_1 = nn.Linear(in_features, out_features, bias=True)
231
+ self.act = nn.SiLU()
232
+ self.linear_2 = nn.Linear(out_features, out_features, bias=True)
233
+
234
+ def forward(self, sample):
235
+ sample = self.linear_1(sample)
236
+ sample = self.act(sample)
237
+ sample = self.linear_2(sample)
238
+
239
+ return sample
240
+
241
+
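# Illustrative sketch (not part of this diff): the 320-dim sinusoidal timestep features
# are lifted to the 1280-dim embedding that every ResnetBlock2D consumes via
# time_emb_proj.
import torch

time_proj = Timesteps(320)
time_embedding = TimestepEmbedding(in_features=320, out_features=1280)
t = torch.tensor([999, 500])         # one diffusion timestep per batch element
temb = time_embedding(time_proj(t))  # shape (2, 1280)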
242
+ class ResnetBlock2D(nn.Module):
243
+ def __init__(self, in_channels, out_channels, conv_shortcut=True):
244
+ super(ResnetBlock2D, self).__init__()
245
+ self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-05, affine=True)
246
+ self.conv1 = nn.Conv2d(
247
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
248
+ )
249
+ self.time_emb_proj = nn.Linear(1280, out_channels, bias=True)
250
+ self.norm2 = nn.GroupNorm(32, out_channels, eps=1e-05, affine=True)
251
+ self.dropout = nn.Dropout(p=0.0, inplace=False)
252
+ self.conv2 = nn.Conv2d(
253
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
254
+ )
255
+ self.nonlinearity = nn.SiLU()
256
+ self.conv_shortcut = None
257
+ if conv_shortcut:
258
+ self.conv_shortcut = nn.Conv2d(
259
+ in_channels, out_channels, kernel_size=1, stride=1
260
+ )
261
+
262
+ def forward(self, input_tensor, temb):
263
+ hidden_states = input_tensor
264
+ hidden_states = self.norm1(hidden_states)
265
+ hidden_states = self.nonlinearity(hidden_states)
266
+
267
+ hidden_states = self.conv1(hidden_states)
268
+
269
+ temb = self.nonlinearity(temb)
270
+ temb = self.time_emb_proj(temb)[:, :, None, None]
271
+ hidden_states = hidden_states + temb
272
+ hidden_states = self.norm2(hidden_states)
273
+
274
+ hidden_states = self.nonlinearity(hidden_states)
275
+ hidden_states = self.dropout(hidden_states)
276
+ hidden_states = self.conv2(hidden_states)
277
+
278
+ if self.conv_shortcut is not None:
279
+ input_tensor = self.conv_shortcut(input_tensor)
280
+
281
+ output_tensor = input_tensor + hidden_states
282
+
283
+ return output_tensor
284
+
285
+
286
+ class Attention(nn.Module):
287
+ def __init__(
288
+ self, inner_dim, cross_attention_dim=None, num_heads=None, dropout=0.0, processor=None, scale_qk=True
289
+ ):
290
+ super(Attention, self).__init__()
291
+ if num_heads is None:
292
+ self.head_dim = 64
293
+ self.num_heads = inner_dim // self.head_dim
294
+ else:
295
+ self.num_heads = num_heads
296
+ self.head_dim = inner_dim // num_heads
297
+
298
+ self.scale = self.head_dim**-0.5
299
+ if cross_attention_dim is None:
300
+ cross_attention_dim = inner_dim
301
+ self.to_q = LoRACompatibleLinear(inner_dim, inner_dim, bias=False)
302
+ self.to_k = LoRACompatibleLinear(cross_attention_dim, inner_dim, bias=False)
303
+ self.to_v = LoRACompatibleLinear(cross_attention_dim, inner_dim, bias=False)
304
+
305
+ self.to_out = nn.ModuleList(
306
+ [LoRACompatibleLinear(inner_dim, inner_dim), nn.Dropout(dropout, inplace=False)]
307
+ )
308
+
309
+ self.scale_qk = scale_qk
310
+ if processor is None:
311
+ processor = (
312
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
313
+ )
314
+ self.set_processor(processor)
315
+
316
+ def forward(
317
+ self,
318
+ hidden_states: torch.FloatTensor,
319
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
320
+ attention_mask: Optional[torch.FloatTensor] = None,
321
+ **cross_attention_kwargs,
322
+ ) -> torch.Tensor:
323
+ r"""
324
+ The forward method of the `Attention` class.
325
+
326
+ Args:
327
+ hidden_states (`torch.Tensor`):
328
+ The hidden states of the query.
329
+ encoder_hidden_states (`torch.Tensor`, *optional*):
330
+ The hidden states of the encoder.
331
+ attention_mask (`torch.Tensor`, *optional*):
332
+ The attention mask to use. If `None`, no mask is applied.
333
+ **cross_attention_kwargs:
334
+ Additional keyword arguments to pass along to the cross attention.
335
+
336
+ Returns:
337
+ `torch.Tensor`: The output of the attention layer.
338
+ """
339
+ # The `Attention` class can call different attention processors / attention functions
340
+ # here we simply pass along all tensors to the selected processor class
341
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
342
+
343
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
344
+ unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
345
+ if len(unused_kwargs) > 0:
346
+ print(
347
+ f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
348
+ )
349
+ cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
350
+
351
+ return self.processor(
352
+ self,
353
+ hidden_states,
354
+ encoder_hidden_states=encoder_hidden_states,
355
+ attention_mask=attention_mask,
356
+ **cross_attention_kwargs,
357
+ )
358
+
359
+ def orig_forward(self, hidden_states, encoder_hidden_states=None):
360
+ q = self.to_q(hidden_states)
361
+ k = (
362
+ self.to_k(encoder_hidden_states)
363
+ if encoder_hidden_states is not None
364
+ else self.to_k(hidden_states)
365
+ )
366
+ v = (
367
+ self.to_v(encoder_hidden_states)
368
+ if encoder_hidden_states is not None
369
+ else self.to_v(hidden_states)
370
+ )
371
+ b, t, c = q.size()
372
+
373
+ q = q.view(q.size(0), q.size(1), self.num_heads, self.head_dim).transpose(1, 2)
374
+ k = k.view(k.size(0), k.size(1), self.num_heads, self.head_dim).transpose(1, 2)
375
+ v = v.view(v.size(0), v.size(1), self.num_heads, self.head_dim).transpose(1, 2)
376
+
377
+ # scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
378
+ # attn_weights = torch.softmax(scores, dim=-1)
379
+ # attn_output = torch.matmul(attn_weights, v)
380
+
381
+ attn_output = F.scaled_dot_product_attention(
382
+ q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=self.scale,
383
+ )
384
+
385
+ attn_output = attn_output.transpose(1, 2).contiguous().view(b, t, c)
386
+
387
+ for layer in self.to_out:
388
+ attn_output = layer(attn_output)
389
+
390
+ return attn_output
391
+
392
+ def set_processor(self, processor) -> None:
393
+ r"""
394
+ Set the attention processor to use.
395
+
396
+ Args:
397
+ processor (`AttnProcessor`):
398
+ The attention processor to use.
399
+ """
400
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
401
+ # pop `processor` from `self._modules`
402
+ if (
403
+ hasattr(self, "processor")
404
+ and isinstance(self.processor, torch.nn.Module)
405
+ and not isinstance(processor, torch.nn.Module)
406
+ ):
407
+ print(f"You are removing possibly trained weights of {self.processor} with {processor}")
408
+ self._modules.pop("processor")
409
+
410
+ self.processor = processor
411
+
412
+ def get_processor(self, return_deprecated_lora: bool = False):
413
+ r"""
414
+ Get the attention processor in use.
415
+
416
+ Args:
417
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
418
+ Set to `True` to return the deprecated LoRA attention processor.
419
+
420
+ Returns:
421
+ "AttentionProcessor": The attention processor in use.
422
+ """
423
+ if not return_deprecated_lora:
424
+ return self.processor
425
+
426
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
427
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
428
+ # with PEFT is completed.
429
+ is_lora_activated = {
430
+ name: module.lora_layer is not None
431
+ for name, module in self.named_modules()
432
+ if hasattr(module, "lora_layer")
433
+ }
434
+
435
+ # 1. if no layer has a LoRA activated we can return the processor as usual
436
+ if not any(is_lora_activated.values()):
437
+ return self.processor
438
+
439
+ # `add_k_proj` and `add_v_proj` may not use LoRA, so exclude them from the check below
440
+ is_lora_activated.pop("add_k_proj", None)
441
+ is_lora_activated.pop("add_v_proj", None)
442
+ # 2. else it is not possible that only some layers have LoRA activated
443
+ if not all(is_lora_activated.values()):
444
+ raise ValueError(
445
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
446
+ )
447
+
448
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
449
+ non_lora_processor_cls_name = self.processor.__class__.__name__
450
+ lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
451
+
452
+ hidden_size = self.inner_dim
453
+
454
+ # now create a LoRA attention processor from the LoRA layers
455
+ if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
456
+ kwargs = {
457
+ "cross_attention_dim": self.cross_attention_dim,
458
+ "rank": self.to_q.lora_layer.rank,
459
+ "network_alpha": self.to_q.lora_layer.network_alpha,
460
+ "q_rank": self.to_q.lora_layer.rank,
461
+ "q_hidden_size": self.to_q.lora_layer.out_features,
462
+ "k_rank": self.to_k.lora_layer.rank,
463
+ "k_hidden_size": self.to_k.lora_layer.out_features,
464
+ "v_rank": self.to_v.lora_layer.rank,
465
+ "v_hidden_size": self.to_v.lora_layer.out_features,
466
+ "out_rank": self.to_out[0].lora_layer.rank,
467
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
468
+ }
469
+
470
+ if hasattr(self.processor, "attention_op"):
471
+ kwargs["attention_op"] = self.processor.attention_op
472
+
473
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
474
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
475
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
476
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
477
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
478
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
479
+ lora_processor = lora_processor_cls(
480
+ hidden_size,
481
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
482
+ rank=self.to_q.lora_layer.rank,
483
+ network_alpha=self.to_q.lora_layer.network_alpha,
484
+ )
485
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
486
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
487
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
488
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
489
+
490
+ # only save if used
491
+ if self.add_k_proj.lora_layer is not None:
492
+ lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
493
+ lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
494
+ else:
495
+ lora_processor.add_k_proj_lora = None
496
+ lora_processor.add_v_proj_lora = None
497
+ else:
498
+ raise ValueError(f"{lora_processor_cls} does not exist.")
499
+
500
+ return lora_processor
501
+
502
+ class GEGLU(nn.Module):
503
+ def __init__(self, in_features, out_features):
504
+ super(GEGLU, self).__init__()
505
+ self.proj = nn.Linear(in_features, out_features * 2, bias=True)
506
+
507
+ def forward(self, x):
508
+ x_proj = self.proj(x)
509
+ x1, x2 = x_proj.chunk(2, dim=-1)
510
+ return x1 * torch.nn.functional.gelu(x2)
511
+
512
+
513
+ class FeedForward(nn.Module):
514
+ def __init__(self, in_features, out_features):
515
+ super(FeedForward, self).__init__()
516
+
517
+ self.net = nn.ModuleList(
518
+ [
519
+ GEGLU(in_features, out_features * 4),
520
+ nn.Dropout(p=0.0, inplace=False),
521
+ nn.Linear(out_features * 4, out_features, bias=True),
522
+ ]
523
+ )
524
+
525
+ def forward(self, x):
526
+ for layer in self.net:
527
+ x = layer(x)
528
+ return x
529
+
530
+
531
+ class BasicTransformerBlock(nn.Module):
532
+ def __init__(self, hidden_size):
533
+ super(BasicTransformerBlock, self).__init__()
534
+ self.norm1 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
535
+ self.attn1 = Attention(hidden_size)
536
+ self.norm2 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
537
+ self.attn2 = Attention(hidden_size, 2048)
538
+ self.norm3 = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=True)
539
+ self.ff = FeedForward(hidden_size, hidden_size)
540
+
541
+ def forward(self, x, encoder_hidden_states=None):
542
+ residual = x
543
+
544
+ x = self.norm1(x)
545
+ x = self.attn1(x)
546
+ x = x + residual
547
+
548
+ residual = x
549
+
550
+ x = self.norm2(x)
551
+ if encoder_hidden_states is not None:
552
+ x = self.attn2(x, encoder_hidden_states)
553
+ else:
554
+ x = self.attn2(x)
555
+ x = x + residual
556
+
557
+ residual = x
558
+
559
+ x = self.norm3(x)
560
+ x = self.ff(x)
561
+ x = x + residual
562
+ return x
563
+
564
+
565
+ class Transformer2DModel(nn.Module):
566
+ def __init__(self, in_channels, out_channels, n_layers):
567
+ super(Transformer2DModel, self).__init__()
568
+ self.norm = nn.GroupNorm(32, in_channels, eps=1e-06, affine=True)
569
+ self.proj_in = nn.Linear(in_channels, out_channels, bias=True)
570
+ self.transformer_blocks = nn.ModuleList(
571
+ [BasicTransformerBlock(out_channels) for _ in range(n_layers)]
572
+ )
573
+ self.proj_out = nn.Linear(out_channels, out_channels, bias=True)
574
+
575
+ def forward(self, hidden_states, encoder_hidden_states=None):
576
+ batch, _, height, width = hidden_states.shape
577
+ res = hidden_states
578
+ hidden_states = self.norm(hidden_states)
579
+ inner_dim = hidden_states.shape[1]
580
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
581
+ batch, height * width, inner_dim
582
+ )
583
+ hidden_states = self.proj_in(hidden_states)
584
+
585
+ for block in self.transformer_blocks:
586
+ hidden_states = block(hidden_states, encoder_hidden_states)
587
+
588
+ hidden_states = self.proj_out(hidden_states)
589
+ hidden_states = (
590
+ hidden_states.reshape(batch, height, width, inner_dim)
591
+ .permute(0, 3, 1, 2)
592
+ .contiguous()
593
+ )
594
+
595
+ return hidden_states + res
596
+
597
+
598
+ class Downsample2D(nn.Module):
599
+ def __init__(self, in_channels, out_channels):
600
+ super(Downsample2D, self).__init__()
601
+ self.conv = nn.Conv2d(
602
+ in_channels, out_channels, kernel_size=3, stride=2, padding=1
603
+ )
604
+
605
+ def forward(self, x):
606
+ return self.conv(x)
607
+
608
+
609
+ class Upsample2D(nn.Module):
610
+ def __init__(self, in_channels, out_channels):
611
+ super(Upsample2D, self).__init__()
612
+ self.conv = nn.Conv2d(
613
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
614
+ )
615
+
616
+ def forward(self, x):
617
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
618
+ return self.conv(x)
619
+
620
+
621
+ class DownBlock2D(nn.Module):
622
+ def __init__(self, in_channels, out_channels):
623
+ super(DownBlock2D, self).__init__()
624
+ self.resnets = nn.ModuleList(
625
+ [
626
+ ResnetBlock2D(in_channels, out_channels, conv_shortcut=False),
627
+ ResnetBlock2D(out_channels, out_channels, conv_shortcut=False),
628
+ ]
629
+ )
630
+ self.downsamplers = nn.ModuleList([Downsample2D(out_channels, out_channels)])
631
+
632
+ def forward(self, hidden_states, temb):
633
+ output_states = []
634
+ for module in self.resnets:
635
+ hidden_states = module(hidden_states, temb)
636
+ output_states.append(hidden_states)
637
+
638
+ hidden_states = self.downsamplers[0](hidden_states)
639
+ output_states.append(hidden_states)
640
+
641
+ return hidden_states, output_states
642
+
643
+
644
+ class CrossAttnDownBlock2D(nn.Module):
645
+ def __init__(self, in_channels, out_channels, n_layers, has_downsamplers=True):
646
+ super(CrossAttnDownBlock2D, self).__init__()
647
+ self.attentions = nn.ModuleList(
648
+ [
649
+ Transformer2DModel(out_channels, out_channels, n_layers),
650
+ Transformer2DModel(out_channels, out_channels, n_layers),
651
+ ]
652
+ )
653
+ self.resnets = nn.ModuleList(
654
+ [
655
+ ResnetBlock2D(in_channels, out_channels),
656
+ ResnetBlock2D(out_channels, out_channels, conv_shortcut=False),
657
+ ]
658
+ )
659
+ self.downsamplers = None
660
+ if has_downsamplers:
661
+ self.downsamplers = nn.ModuleList(
662
+ [Downsample2D(out_channels, out_channels)]
663
+ )
664
+
665
+ def forward(self, hidden_states, temb, encoder_hidden_states):
666
+ output_states = []
667
+ for resnet, attn in zip(self.resnets, self.attentions):
668
+ hidden_states = resnet(hidden_states, temb)
669
+ hidden_states = attn(
670
+ hidden_states,
671
+ encoder_hidden_states=encoder_hidden_states,
672
+ )
673
+ output_states.append(hidden_states)
674
+
675
+ if self.downsamplers is not None:
676
+ hidden_states = self.downsamplers[0](hidden_states)
677
+ output_states.append(hidden_states)
678
+
679
+ return hidden_states, output_states
680
+
681
+
682
+ class CrossAttnUpBlock2D(nn.Module):
683
+ def __init__(self, in_channels, out_channels, prev_output_channel, n_layers):
684
+ super(CrossAttnUpBlock2D, self).__init__()
685
+ self.attentions = nn.ModuleList(
686
+ [
687
+ Transformer2DModel(out_channels, out_channels, n_layers),
688
+ Transformer2DModel(out_channels, out_channels, n_layers),
689
+ Transformer2DModel(out_channels, out_channels, n_layers),
690
+ ]
691
+ )
692
+ self.resnets = nn.ModuleList(
693
+ [
694
+ ResnetBlock2D(prev_output_channel + out_channels, out_channels),
695
+ ResnetBlock2D(2 * out_channels, out_channels),
696
+ ResnetBlock2D(out_channels + in_channels, out_channels),
697
+ ]
698
+ )
699
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, out_channels)])
700
+
701
+ def forward(
702
+ self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states
703
+ ):
704
+ for resnet, attn in zip(self.resnets, self.attentions):
705
+ # pop res hidden states
706
+ res_hidden_states = res_hidden_states_tuple[-1]
707
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
708
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
709
+ hidden_states = resnet(hidden_states, temb)
710
+ hidden_states = attn(
711
+ hidden_states,
712
+ encoder_hidden_states=encoder_hidden_states,
713
+ )
714
+
715
+ if self.upsamplers is not None:
716
+ for upsampler in self.upsamplers:
717
+ hidden_states = upsampler(hidden_states)
718
+
719
+ return hidden_states
720
+
721
+
722
+ class UpBlock2D(nn.Module):
723
+ def __init__(self, in_channels, out_channels, prev_output_channel):
724
+ super(UpBlock2D, self).__init__()
725
+ self.resnets = nn.ModuleList(
726
+ [
727
+ ResnetBlock2D(out_channels + prev_output_channel, out_channels),
728
+ ResnetBlock2D(out_channels * 2, out_channels),
729
+ ResnetBlock2D(out_channels + in_channels, out_channels),
730
+ ]
731
+ )
732
+
733
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
734
+
735
+ is_freeu_enabled = (
736
+ getattr(self, "s1", None)
737
+ and getattr(self, "s2", None)
738
+ and getattr(self, "b1", None)
739
+ and getattr(self, "b2", None)
740
+ and getattr(self, "resolution_idx", None)
741
+ )
742
+
743
+ for resnet in self.resnets:
744
+ res_hidden_states = res_hidden_states_tuple[-1]
745
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
746
+
747
+
748
+ if is_freeu_enabled:
749
+ hidden_states, res_hidden_states = apply_freeu(
750
+ self.resolution_idx,
751
+ hidden_states,
752
+ res_hidden_states,
753
+ s1=self.s1,
754
+ s2=self.s2,
755
+ b1=self.b1,
756
+ b2=self.b2,
757
+ )
758
+
759
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
760
+ hidden_states = resnet(hidden_states, temb)
761
+
762
+ return hidden_states
763
+
764
+ class UNetMidBlock2DCrossAttn(nn.Module):
765
+ def __init__(self, in_features):
766
+ super(UNetMidBlock2DCrossAttn, self).__init__()
767
+ self.attentions = nn.ModuleList(
768
+ [Transformer2DModel(in_features, in_features, n_layers=10)]
769
+ )
770
+ self.resnets = nn.ModuleList(
771
+ [
772
+ ResnetBlock2D(in_features, in_features, conv_shortcut=False),
773
+ ResnetBlock2D(in_features, in_features, conv_shortcut=False),
774
+ ]
775
+ )
776
+
777
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
778
+ hidden_states = self.resnets[0](hidden_states, temb)
779
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
780
+ hidden_states = attn(
781
+ hidden_states,
782
+ encoder_hidden_states=encoder_hidden_states,
783
+ )
784
+ hidden_states = resnet(hidden_states, temb)
785
+
786
+ return hidden_states
787
+
788
+
789
+ class UNet2DConditionModel(nn.Module):
790
+ def __init__(self):
791
+ super(UNet2DConditionModel, self).__init__()
792
+
793
+ # This namedtuple imitates the Hugging Face config interface;
794
+ # it has nothing to do with the model itself.
795
+ # Remove it if you don't use the diffusers pipeline.
796
+ self.config = namedtuple(
797
+ "config", "in_channels addition_time_embed_dim sample_size"
798
+ )
799
+ self.config.in_channels = 4
800
+ self.config.addition_time_embed_dim = 256
801
+ self.config.sample_size = 128
802
+
803
+ self.conv_in = nn.Conv2d(4, 320, kernel_size=3, stride=1, padding=1)
804
+ self.time_proj = Timesteps()
805
+ self.time_embedding = TimestepEmbedding(in_features=320, out_features=1280)
806
+ self.add_time_proj = Timesteps(256)
807
+ self.add_embedding = TimestepEmbedding(in_features=2816, out_features=1280)
808
+ self.down_blocks = nn.ModuleList(
809
+ [
810
+ DownBlock2D(in_channels=320, out_channels=320),
811
+ CrossAttnDownBlock2D(in_channels=320, out_channels=640, n_layers=2),
812
+ CrossAttnDownBlock2D(
813
+ in_channels=640,
814
+ out_channels=1280,
815
+ n_layers=10,
816
+ has_downsamplers=False,
817
+ ),
818
+ ]
819
+ )
820
+ self.up_blocks = nn.ModuleList(
821
+ [
822
+ CrossAttnUpBlock2D(
823
+ in_channels=640,
824
+ out_channels=1280,
825
+ prev_output_channel=1280,
826
+ n_layers=10,
827
+ ),
828
+ CrossAttnUpBlock2D(
829
+ in_channels=320,
830
+ out_channels=640,
831
+ prev_output_channel=1280,
832
+ n_layers=2,
833
+ ),
834
+ UpBlock2D(in_channels=320, out_channels=320, prev_output_channel=640),
835
+ ]
836
+ )
837
+ self.mid_block = UNetMidBlock2DCrossAttn(1280)
838
+ self.conv_norm_out = nn.GroupNorm(32, 320, eps=1e-05, affine=True)
839
+ self.conv_act = nn.SiLU()
840
+ self.conv_out = nn.Conv2d(320, 4, kernel_size=3, stride=1, padding=1)
841
+
842
+ def forward(
843
+ self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, **kwargs
844
+ ):
845
+ # Implement the forward pass through the model
846
+ timesteps = timesteps.expand(sample.shape[0])
847
+ t_emb = self.time_proj(timesteps).to(dtype=sample.dtype)
848
+
849
+ emb = self.time_embedding(t_emb)
850
+
851
+ text_embeds = added_cond_kwargs.get("text_embeds")
852
+ time_ids = added_cond_kwargs.get("time_ids")
853
+
854
+ time_embeds = self.add_time_proj(time_ids.flatten())
855
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
856
+
857
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
858
+ add_embeds = add_embeds.to(emb.dtype)
859
+ aug_emb = self.add_embedding(add_embeds)
860
+
861
+ emb = emb + aug_emb
862
+
863
+ sample = self.conv_in(sample)
864
+
865
+ # 3. down
866
+ s0 = sample
867
+ sample, [s1, s2, s3] = self.down_blocks[0](
868
+ sample,
869
+ temb=emb,
870
+ )
871
+
872
+ sample, [s4, s5, s6] = self.down_blocks[1](
873
+ sample,
874
+ temb=emb,
875
+ encoder_hidden_states=encoder_hidden_states,
876
+ )
877
+
878
+ sample, [s7, s8] = self.down_blocks[2](
879
+ sample,
880
+ temb=emb,
881
+ encoder_hidden_states=encoder_hidden_states,
882
+ )
883
+
884
+ # 4. mid
885
+ sample = self.mid_block(
886
+ sample, emb, encoder_hidden_states=encoder_hidden_states
887
+ )
888
+
889
+ # 5. up
890
+ sample = self.up_blocks[0](
891
+ hidden_states=sample,
892
+ temb=emb,
893
+ res_hidden_states_tuple=[s6, s7, s8],
894
+ encoder_hidden_states=encoder_hidden_states,
895
+ )
896
+
897
+ sample = self.up_blocks[1](
898
+ hidden_states=sample,
899
+ temb=emb,
900
+ res_hidden_states_tuple=[s3, s4, s5],
901
+ encoder_hidden_states=encoder_hidden_states,
902
+ )
903
+
904
+ sample = self.up_blocks[2](
905
+ hidden_states=sample,
906
+ temb=emb,
907
+ res_hidden_states_tuple=[s0, s1, s2],
908
+ )
909
+
910
+ # 6. post-process
911
+ sample = self.conv_norm_out(sample)
912
+ sample = self.conv_act(sample)
913
+ sample = self.conv_out(sample)
914
+
915
+ return [sample]
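# Illustrative sketch (not part of this diff): one denoising step through the minimal
# UNet above with SDXL-shaped dummy inputs. Random weights only; note the full-size
# UNet needs several GB of RAM.
import torch

unet = UNet2DConditionModel()
sample = torch.randn(1, 4, 64, 64)        # dummy latent
timestep = torch.tensor([999])
prompt_embeds = torch.randn(1, 77, 2048)  # concatenated CLIP text hidden states
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 1280),  # pooled text embedding
    "time_ids": torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]]),
}
with torch.no_grad():
    noise_pred = unet(sample, timestep, prompt_embeds, added_cond_kwargs)[0]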
module/unet/unet_2d_ZeroSFT.py ADDED
@@ -0,0 +1,1397 @@
1
+ # Copied from diffusers.models.unets.unet_2d_condition.py
2
+
3
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from dataclasses import dataclass
17
+ from typing import Any, Dict, List, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.utils.checkpoint
22
+
23
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
24
+ from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
25
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
26
+ from diffusers.models.activations import get_activation
27
+ from diffusers.models.attention_processor import (
28
+ ADDED_KV_ATTENTION_PROCESSORS,
29
+ CROSS_ATTENTION_PROCESSORS,
30
+ Attention,
31
+ AttentionProcessor,
32
+ AttnAddedKVProcessor,
33
+ AttnProcessor,
34
+ )
35
+ from diffusers.models.embeddings import (
36
+ GaussianFourierProjection,
37
+ GLIGENTextBoundingboxProjection,
38
+ ImageHintTimeEmbedding,
39
+ ImageProjection,
40
+ ImageTimeEmbedding,
41
+ TextImageProjection,
42
+ TextImageTimeEmbedding,
43
+ TextTimeEmbedding,
44
+ TimestepEmbedding,
45
+ Timesteps,
46
+ )
47
+ from diffusers.models.modeling_utils import ModelMixin
48
+ from .unet_2d_ZeroSFT_blocks import (
49
+ get_down_block,
50
+ get_mid_block,
51
+ get_up_block,
52
+ )
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+
58
+ def zero_module(module):
59
+ for p in module.parameters():
60
+ nn.init.zeros_(p)
61
+ return module
62
+
63
+
64
+ class ZeroConv(nn.Module):
65
+ def __init__(self, label_nc, norm_nc, mask=False):
66
+ super().__init__()
67
+ self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0))
68
+ self.mask = mask
69
+
70
+ def forward(self, c, h, h_ori=None):
71
+ # with torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
72
+ if not self.mask:
73
+ h = h + self.zero_conv(c)
74
+ else:
75
+ h = h + self.zero_conv(c) * torch.zeros_like(h)
76
+ if h_ori is not None:
77
+ h = torch.cat([h_ori, h], dim=1)
78
+ return h
79
+
80
+
81
+ class ZeroSFT(nn.Module):
82
+ def __init__(self, label_nc, norm_nc, concat_channels=0, norm=True, mask=False):
83
+ super().__init__()
84
+
85
+ # param_free_norm_type = str(parsed.group(1))
86
+ ks = 3
87
+ pw = ks // 2
88
+
89
+ self.mask = mask
90
+ self.norm = norm
91
+ self.pre_concat = bool(concat_channels != 0)
92
+ if self.norm:
93
+ self.param_free_norm = torch.nn.GroupNorm(num_groups=32, num_channels=norm_nc + concat_channels)
94
+ else:
95
+ self.param_free_norm = nn.Identity()
96
+
97
+ nhidden = 128
98
+
99
+ self.mlp_shared = nn.Sequential(
100
+ nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
101
+ nn.SiLU()
102
+ )
103
+ self.zero_mul = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
104
+ self.zero_add = zero_module(nn.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw))
105
+
106
+ self.zero_conv = zero_module(nn.Conv2d(label_nc, norm_nc, 1, 1, 0))
107
+
108
+ def forward(self, down_block_res_samples, h_ori=None, control_scale=1.0, mask=False):
109
+ mask = mask or self.mask
110
+ assert mask is False
111
+ if self.pre_concat:
112
+ assert h_ori is not None
113
+
114
+ c, h = down_block_res_samples
115
+ if h_ori is not None:
116
+ h_raw = torch.cat([h_ori, h], dim=1)
117
+ else:
118
+ h_raw = h
119
+
120
+ if self.mask:
121
+ h = h + self.zero_conv(c) * torch.zeros_like(h)
122
+ else:
123
+ h = h + self.zero_conv(c)
124
+ if h_ori is not None and self.pre_concat:
125
+ h = torch.cat([h_ori, h], dim=1)
126
+ actv = self.mlp_shared(c)
127
+ gamma = self.zero_mul(actv)
128
+ beta = self.zero_add(actv)
129
+ if self.mask:
130
+ gamma = gamma * torch.zeros_like(gamma)
131
+ beta = beta * torch.zeros_like(beta)
132
+ # h = h + self.param_free_norm(h) * gamma + beta
133
+ h = self.param_free_norm(h) * (gamma + 1) + beta
134
+ if h_ori is not None and not self.pre_concat:
135
+ h = torch.cat([h_ori, h], dim=1)
136
+ return h * control_scale + h_raw * (1 - control_scale)
137
+
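# Illustrative sketch (not part of this diff): ZeroSFT modulates UNet features with
# conditioning features. Because zero_conv, zero_mul and zero_add are zero-initialized,
# the conditioning branch contributes nothing at the start of training beyond the
# parameter-free GroupNorm.
import torch

sft = ZeroSFT(label_nc=320, norm_nc=320)
cond = torch.randn(1, 320, 64, 64)  # conditioning features
feat = torch.randn(1, 320, 64, 64)  # UNet features to modulate
out = sft((cond, feat), control_scale=1.0)  # same shape as `feat`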
138
+
139
+ @dataclass
140
+ class UNet2DConditionOutput(BaseOutput):
141
+ """
142
+ The output of [`UNet2DConditionModel`].
143
+
144
+ Args:
145
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
146
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
147
+ """
148
+
149
+ sample: torch.FloatTensor = None
150
+
151
+
152
+ class UNet2DZeroSFTModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
153
+ r"""
154
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
155
+ shaped output.
156
+
157
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
158
+ for all models (such as downloading or saving).
159
+
160
+ Parameters:
161
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
162
+ Height and width of input/output sample.
163
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
164
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
165
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
166
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
167
+ Whether to flip the sin to cos in the time embedding.
168
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
169
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
170
+ The tuple of downsample blocks to use.
171
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
172
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
173
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
174
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
175
+ The tuple of upsample blocks to use.
176
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
177
+ Whether to include self-attention in the basic transformer blocks, see
178
+ [`~models.attention.BasicTransformerBlock`].
179
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
180
+ The tuple of output channels for each block.
181
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
182
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
183
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
184
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
185
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
186
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
187
+ If `None`, normalization and activation layers are skipped in post-processing.
188
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
189
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
190
+ The dimension of the cross attention features.
191
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
192
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
193
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
194
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
195
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
196
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
197
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
198
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
199
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
200
+ encoder_hid_dim (`int`, *optional*, defaults to None):
201
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
202
+ dimension to `cross_attention_dim`.
203
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
204
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
205
+ embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
206
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
207
+ num_attention_heads (`int`, *optional*):
208
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
209
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
210
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
211
+ class_embed_type (`str`, *optional*, defaults to `None`):
212
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
213
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
214
+ addition_embed_type (`str`, *optional*, defaults to `None`):
215
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
216
+ "text". "text" will use the `TextTimeEmbedding` layer.
217
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
218
+ Dimension for the timestep embeddings.
219
+ num_class_embeds (`int`, *optional*, defaults to `None`):
220
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
221
+ class conditioning with `class_embed_type` equal to `None`.
222
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
223
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
224
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
225
+ An optional override for the dimension of the projected time embedding.
226
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
227
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
228
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
229
+ timestep_post_act (`str`, *optional*, defaults to `None`):
230
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
231
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
232
+ The dimension of `cond_proj` layer in the timestep embedding.
233
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
234
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
235
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
236
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
237
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
238
+ embeddings with the class embeddings.
239
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
240
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
241
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
242
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
243
+ otherwise.
244
+ """
245
+
246
+ _supports_gradient_checkpointing = True
247
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
248
+
249
+ @register_to_config
250
+ def __init__(
251
+ self,
252
+ sample_size: Optional[int] = None,
253
+ in_channels: int = 4,
254
+ out_channels: int = 4,
255
+ center_input_sample: bool = False,
256
+ flip_sin_to_cos: bool = True,
257
+ freq_shift: int = 0,
258
+ down_block_types: Tuple[str] = (
259
+ "CrossAttnDownBlock2D",
260
+ "CrossAttnDownBlock2D",
261
+ "CrossAttnDownBlock2D",
262
+ "DownBlock2D",
263
+ ),
264
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
265
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
266
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
267
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
268
+ layers_per_block: Union[int, Tuple[int]] = 2,
269
+ downsample_padding: int = 1,
270
+ mid_block_scale_factor: float = 1,
271
+ dropout: float = 0.0,
272
+ act_fn: str = "silu",
273
+ norm_num_groups: Optional[int] = 32,
274
+ norm_eps: float = 1e-5,
275
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
276
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
277
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
278
+ encoder_hid_dim: Optional[int] = None,
279
+ encoder_hid_dim_type: Optional[str] = None,
280
+ attention_head_dim: Union[int, Tuple[int]] = 8,
281
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
282
+ dual_cross_attention: bool = False,
283
+ use_linear_projection: bool = False,
284
+ class_embed_type: Optional[str] = None,
285
+ addition_embed_type: Optional[str] = None,
286
+ addition_time_embed_dim: Optional[int] = None,
287
+ num_class_embeds: Optional[int] = None,
288
+ upcast_attention: bool = False,
289
+ resnet_time_scale_shift: str = "default",
290
+ resnet_skip_time_act: bool = False,
291
+ resnet_out_scale_factor: float = 1.0,
292
+ time_embedding_type: str = "positional",
293
+ time_embedding_dim: Optional[int] = None,
294
+ time_embedding_act_fn: Optional[str] = None,
295
+ timestep_post_act: Optional[str] = None,
296
+ time_cond_proj_dim: Optional[int] = None,
297
+ conv_in_kernel: int = 3,
298
+ conv_out_kernel: int = 3,
299
+ projection_class_embeddings_input_dim: Optional[int] = None,
300
+ attention_type: str = "default",
301
+ class_embeddings_concat: bool = False,
302
+ mid_block_only_cross_attention: Optional[bool] = None,
303
+ cross_attention_norm: Optional[str] = None,
304
+ addition_embed_type_num_heads: int = 64,
305
+ ):
306
+ super().__init__()
307
+
308
+ self.sample_size = sample_size
309
+
310
+ if num_attention_heads is not None:
311
+ raise ValueError(
312
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
313
+ )
314
+
315
+ # If `num_attention_heads` is not defined (which is the case for most models)
316
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
317
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
318
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
319
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
320
+ # which is why we correct for the naming here.
321
+ num_attention_heads = num_attention_heads or attention_head_dim
322
+
323
+ # Check inputs
324
+ self._check_config(
325
+ down_block_types=down_block_types,
326
+ up_block_types=up_block_types,
327
+ only_cross_attention=only_cross_attention,
328
+ block_out_channels=block_out_channels,
329
+ layers_per_block=layers_per_block,
330
+ cross_attention_dim=cross_attention_dim,
331
+ transformer_layers_per_block=transformer_layers_per_block,
332
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
333
+ attention_head_dim=attention_head_dim,
334
+ num_attention_heads=num_attention_heads,
335
+ )
336
+
337
+ # input
338
+ conv_in_padding = (conv_in_kernel - 1) // 2
339
+ self.conv_in = nn.Conv2d(
340
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
341
+ )
342
+
343
+ # time
344
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
345
+ time_embedding_type,
346
+ block_out_channels=block_out_channels,
347
+ flip_sin_to_cos=flip_sin_to_cos,
348
+ freq_shift=freq_shift,
349
+ time_embedding_dim=time_embedding_dim,
350
+ )
351
+
352
+ self.time_embedding = TimestepEmbedding(
353
+ timestep_input_dim,
354
+ time_embed_dim,
355
+ act_fn=act_fn,
356
+ post_act_fn=timestep_post_act,
357
+ cond_proj_dim=time_cond_proj_dim,
358
+ )
359
+
360
+ self._set_encoder_hid_proj(
361
+ encoder_hid_dim_type,
362
+ cross_attention_dim=cross_attention_dim,
363
+ encoder_hid_dim=encoder_hid_dim,
364
+ )
365
+
366
+ # class embedding
367
+ self._set_class_embedding(
368
+ class_embed_type,
369
+ act_fn=act_fn,
370
+ num_class_embeds=num_class_embeds,
371
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
372
+ time_embed_dim=time_embed_dim,
373
+ timestep_input_dim=timestep_input_dim,
374
+ )
375
+
376
+ self._set_add_embedding(
377
+ addition_embed_type,
378
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
379
+ addition_time_embed_dim=addition_time_embed_dim,
380
+ cross_attention_dim=cross_attention_dim,
381
+ encoder_hid_dim=encoder_hid_dim,
382
+ flip_sin_to_cos=flip_sin_to_cos,
383
+ freq_shift=freq_shift,
384
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
385
+ time_embed_dim=time_embed_dim,
386
+ )
387
+
388
+ if time_embedding_act_fn is None:
389
+ self.time_embed_act = None
390
+ else:
391
+ self.time_embed_act = get_activation(time_embedding_act_fn)
392
+
393
+ self.down_blocks = nn.ModuleList([])
394
+ self.up_blocks = nn.ModuleList([])
395
+
396
+ if isinstance(only_cross_attention, bool):
397
+ if mid_block_only_cross_attention is None:
398
+ mid_block_only_cross_attention = only_cross_attention
399
+
400
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
401
+
402
+ if mid_block_only_cross_attention is None:
403
+ mid_block_only_cross_attention = False
404
+
405
+ if isinstance(num_attention_heads, int):
406
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
407
+
408
+ if isinstance(attention_head_dim, int):
409
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
410
+
411
+ if isinstance(cross_attention_dim, int):
412
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
413
+
414
+ if isinstance(layers_per_block, int):
415
+ layers_per_block = [layers_per_block] * len(down_block_types)
416
+
417
+ if isinstance(transformer_layers_per_block, int):
418
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
419
+
420
+ if class_embeddings_concat:
421
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
422
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
423
+ # regular time embeddings
424
+ blocks_time_embed_dim = time_embed_dim * 2
425
+ else:
426
+ blocks_time_embed_dim = time_embed_dim
427
+
428
+ # down
429
+ output_channel = block_out_channels[0]
430
+ for i, down_block_type in enumerate(down_block_types):
431
+ input_channel = output_channel
432
+ output_channel = block_out_channels[i]
433
+ is_final_block = i == len(block_out_channels) - 1
434
+
435
+ down_block = get_down_block(
436
+ down_block_type,
437
+ num_layers=layers_per_block[i],
438
+ transformer_layers_per_block=transformer_layers_per_block[i],
439
+ in_channels=input_channel,
440
+ out_channels=output_channel,
441
+ temb_channels=blocks_time_embed_dim,
442
+ add_downsample=not is_final_block,
443
+ resnet_eps=norm_eps,
444
+ resnet_act_fn=act_fn,
445
+ resnet_groups=norm_num_groups,
446
+ cross_attention_dim=cross_attention_dim[i],
447
+ num_attention_heads=num_attention_heads[i],
448
+ downsample_padding=downsample_padding,
449
+ dual_cross_attention=dual_cross_attention,
450
+ use_linear_projection=use_linear_projection,
451
+ only_cross_attention=only_cross_attention[i],
452
+ upcast_attention=upcast_attention,
453
+ resnet_time_scale_shift=resnet_time_scale_shift,
454
+ attention_type=attention_type,
455
+ resnet_skip_time_act=resnet_skip_time_act,
456
+ resnet_out_scale_factor=resnet_out_scale_factor,
457
+ cross_attention_norm=cross_attention_norm,
458
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
459
+ dropout=dropout,
460
+ )
461
+ self.down_blocks.append(down_block)
462
+
463
+ # mid
464
+ self.mid_block = get_mid_block(
465
+ mid_block_type,
466
+ temb_channels=blocks_time_embed_dim,
467
+ in_channels=block_out_channels[-1],
468
+ resnet_eps=norm_eps,
469
+ resnet_act_fn=act_fn,
470
+ resnet_groups=norm_num_groups,
471
+ output_scale_factor=mid_block_scale_factor,
472
+ transformer_layers_per_block=transformer_layers_per_block[-1],
473
+ num_attention_heads=num_attention_heads[-1],
474
+ cross_attention_dim=cross_attention_dim[-1],
475
+ dual_cross_attention=dual_cross_attention,
476
+ use_linear_projection=use_linear_projection,
477
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
478
+ upcast_attention=upcast_attention,
479
+ resnet_time_scale_shift=resnet_time_scale_shift,
480
+ attention_type=attention_type,
481
+ resnet_skip_time_act=resnet_skip_time_act,
482
+ cross_attention_norm=cross_attention_norm,
483
+ attention_head_dim=attention_head_dim[-1],
484
+ dropout=dropout,
485
+ )
486
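+ # ZeroSFT is a zero-initialized spatial-feature-transform block; this instance fuses the
+ # ControlNet/aggregator mid-block residual into the UNet mid-block features (see step 4 of `forward`)
+ # instead of adding the residual directly.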
+ self.mid_zero_SFT = ZeroSFT(block_out_channels[-1], block_out_channels[-1], 0)
487
+
488
+ # count how many layers upsample the images
489
+ self.num_upsamplers = 0
490
+
491
+ # up
492
+ reversed_block_out_channels = list(reversed(block_out_channels))
493
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
494
+ reversed_layers_per_block = list(reversed(layers_per_block))
495
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
496
+ reversed_transformer_layers_per_block = (
497
+ list(reversed(transformer_layers_per_block))
498
+ if reverse_transformer_layers_per_block is None
499
+ else reverse_transformer_layers_per_block
500
+ )
501
+ only_cross_attention = list(reversed(only_cross_attention))
502
+
503
+ output_channel = reversed_block_out_channels[0]
504
+ for i, up_block_type in enumerate(up_block_types):
505
+ is_final_block = i == len(block_out_channels) - 1
506
+
507
+ prev_output_channel = output_channel
508
+ output_channel = reversed_block_out_channels[i]
509
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
510
+
511
+ # add upsample block for all BUT final layer
512
+ if not is_final_block:
513
+ add_upsample = True
514
+ self.num_upsamplers += 1
515
+ else:
516
+ add_upsample = False
517
+
518
+ up_block = get_up_block(
519
+ up_block_type,
520
+ num_layers=reversed_layers_per_block[i] + 1,
521
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
522
+ in_channels=input_channel,
523
+ out_channels=output_channel,
524
+ prev_output_channel=prev_output_channel,
525
+ temb_channels=blocks_time_embed_dim,
526
+ add_upsample=add_upsample,
527
+ resnet_eps=norm_eps,
528
+ resnet_act_fn=act_fn,
529
+ resolution_idx=i,
530
+ resnet_groups=norm_num_groups,
531
+ cross_attention_dim=reversed_cross_attention_dim[i],
532
+ num_attention_heads=reversed_num_attention_heads[i],
533
+ dual_cross_attention=dual_cross_attention,
534
+ use_linear_projection=use_linear_projection,
535
+ only_cross_attention=only_cross_attention[i],
536
+ upcast_attention=upcast_attention,
537
+ resnet_time_scale_shift=resnet_time_scale_shift,
538
+ attention_type=attention_type,
539
+ resnet_skip_time_act=resnet_skip_time_act,
540
+ resnet_out_scale_factor=resnet_out_scale_factor,
541
+ cross_attention_norm=cross_attention_norm,
542
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
543
+ dropout=dropout,
544
+ )
545
+ self.up_blocks.append(up_block)
546
+ prev_output_channel = output_channel
547
+
548
+ # out
549
+ if norm_num_groups is not None:
550
+ self.conv_norm_out = nn.GroupNorm(
551
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
552
+ )
553
+
554
+ self.conv_act = get_activation(act_fn)
555
+
556
+ else:
557
+ self.conv_norm_out = None
558
+ self.conv_act = None
559
+
560
+ conv_out_padding = (conv_out_kernel - 1) // 2
561
+ self.conv_out = nn.Conv2d(
562
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
563
+ )
564
+
565
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
566
+
567
+ def _check_config(
568
+ self,
569
+ down_block_types: Tuple[str],
570
+ up_block_types: Tuple[str],
571
+ only_cross_attention: Union[bool, Tuple[bool]],
572
+ block_out_channels: Tuple[int],
573
+ layers_per_block: Union[int, Tuple[int]],
574
+ cross_attention_dim: Union[int, Tuple[int]],
575
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
576
+ reverse_transformer_layers_per_block: bool,
577
+ attention_head_dim: int,
578
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
579
+ ):
580
+ if len(down_block_types) != len(up_block_types):
581
+ raise ValueError(
582
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
583
+ )
584
+
585
+ if len(block_out_channels) != len(down_block_types):
586
+ raise ValueError(
587
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
588
+ )
589
+
590
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
591
+ raise ValueError(
592
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
593
+ )
594
+
595
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
596
+ raise ValueError(
597
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
598
+ )
599
+
600
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
601
+ raise ValueError(
602
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
603
+ )
604
+
605
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
606
+ raise ValueError(
607
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
608
+ )
609
+
610
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
611
+ raise ValueError(
612
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
613
+ )
614
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
615
+ for layer_number_per_block in transformer_layers_per_block:
616
+ if isinstance(layer_number_per_block, list):
617
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
618
+
619
+ def _set_time_proj(
620
+ self,
621
+ time_embedding_type: str,
622
+ block_out_channels: int,
623
+ flip_sin_to_cos: bool,
624
+ freq_shift: float,
625
+ time_embedding_dim: int,
626
+ ) -> Tuple[int, int]:
627
+ if time_embedding_type == "fourier":
628
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
629
+ if time_embed_dim % 2 != 0:
630
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
631
+ self.time_proj = GaussianFourierProjection(
632
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
633
+ )
634
+ timestep_input_dim = time_embed_dim
635
+ elif time_embedding_type == "positional":
636
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
637
+
638
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
639
+ timestep_input_dim = block_out_channels[0]
640
+ else:
641
+ raise ValueError(
642
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
643
+ )
644
+
645
+ return time_embed_dim, timestep_input_dim
646
+
647
+ def _set_encoder_hid_proj(
648
+ self,
649
+ encoder_hid_dim_type: Optional[str],
650
+ cross_attention_dim: Union[int, Tuple[int]],
651
+ encoder_hid_dim: Optional[int],
652
+ ):
653
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
654
+ encoder_hid_dim_type = "text_proj"
655
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
656
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
657
+
658
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
659
+ raise ValueError(
660
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
661
+ )
662
+
663
+ if encoder_hid_dim_type == "text_proj":
664
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
665
+ elif encoder_hid_dim_type == "text_image_proj":
666
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
667
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
668
+ # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
669
+ self.encoder_hid_proj = TextImageProjection(
670
+ text_embed_dim=encoder_hid_dim,
671
+ image_embed_dim=cross_attention_dim,
672
+ cross_attention_dim=cross_attention_dim,
673
+ )
674
+ elif encoder_hid_dim_type == "image_proj":
675
+ # Kandinsky 2.2
676
+ self.encoder_hid_proj = ImageProjection(
677
+ image_embed_dim=encoder_hid_dim,
678
+ cross_attention_dim=cross_attention_dim,
679
+ )
680
+ elif encoder_hid_dim_type is not None:
681
+ raise ValueError(
682
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
683
+ )
684
+ else:
685
+ self.encoder_hid_proj = None
686
+
687
+ def _set_class_embedding(
688
+ self,
689
+ class_embed_type: Optional[str],
690
+ act_fn: str,
691
+ num_class_embeds: Optional[int],
692
+ projection_class_embeddings_input_dim: Optional[int],
693
+ time_embed_dim: int,
694
+ timestep_input_dim: int,
695
+ ):
696
+ if class_embed_type is None and num_class_embeds is not None:
697
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
698
+ elif class_embed_type == "timestep":
699
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
700
+ elif class_embed_type == "identity":
701
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
702
+ elif class_embed_type == "projection":
703
+ if projection_class_embeddings_input_dim is None:
704
+ raise ValueError(
705
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
706
+ )
707
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
708
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
709
+ # 2. it projects from an arbitrary input dimension.
710
+ #
711
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
712
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
713
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
714
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
715
+ elif class_embed_type == "simple_projection":
716
+ if projection_class_embeddings_input_dim is None:
717
+ raise ValueError(
718
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
719
+ )
720
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
721
+ else:
722
+ self.class_embedding = None
723
+
724
+ def _set_add_embedding(
725
+ self,
726
+ addition_embed_type: str,
727
+ addition_embed_type_num_heads: int,
728
+ addition_time_embed_dim: Optional[int],
729
+ flip_sin_to_cos: bool,
730
+ freq_shift: float,
731
+ cross_attention_dim: Optional[int],
732
+ encoder_hid_dim: Optional[int],
733
+ projection_class_embeddings_input_dim: Optional[int],
734
+ time_embed_dim: int,
735
+ ):
736
+ if addition_embed_type == "text":
737
+ if encoder_hid_dim is not None:
738
+ text_time_embedding_from_dim = encoder_hid_dim
739
+ else:
740
+ text_time_embedding_from_dim = cross_attention_dim
741
+
742
+ self.add_embedding = TextTimeEmbedding(
743
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
744
+ )
745
+ elif addition_embed_type == "text_image":
746
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
747
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
748
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
749
+ self.add_embedding = TextImageTimeEmbedding(
750
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
751
+ )
752
+ elif addition_embed_type == "text_time":
753
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
754
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
755
+ elif addition_embed_type == "image":
756
+ # Kandinsky 2.2
757
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
758
+ elif addition_embed_type == "image_hint":
759
+ # Kandinsky 2.2 ControlNet
760
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
761
+ elif addition_embed_type is not None:
762
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.")
763
+
764
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
765
+ if attention_type in ["gated", "gated-text-image"]:
766
+ positive_len = 768
767
+ if isinstance(cross_attention_dim, int):
768
+ positive_len = cross_attention_dim
769
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
770
+ positive_len = cross_attention_dim[0]
771
+
772
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
773
+ self.position_net = GLIGENTextBoundingboxProjection(
774
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
775
+ )
776
+
777
+ @property
778
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
779
+ r"""
780
+ Returns:
781
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
783
+ """
784
+ # set recursively
785
+ processors = {}
786
+
787
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
788
+ if hasattr(module, "get_processor"):
789
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
790
+
791
+ for sub_name, child in module.named_children():
792
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
793
+
794
+ return processors
795
+
796
+ for name, module in self.named_children():
797
+ fn_recursive_add_processors(name, module, processors)
798
+
799
+ return processors
800
+
801
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
802
+ r"""
803
+ Sets the attention processor to use to compute attention.
804
+
805
+ Parameters:
806
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
807
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
808
+ for **all** `Attention` layers.
809
+
810
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
811
+ processor. This is strongly recommended when setting trainable attention processors.
812
+
813
+ """
814
+ count = len(self.attn_processors.keys())
815
+
816
+ if isinstance(processor, dict) and len(processor) != count:
817
+ raise ValueError(
818
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
819
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
820
+ )
821
+
822
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
823
+ if hasattr(module, "set_processor"):
824
+ if not isinstance(processor, dict):
825
+ module.set_processor(processor)
826
+ else:
827
+ module.set_processor(processor.pop(f"{name}.processor"))
828
+
829
+ for sub_name, child in module.named_children():
830
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
831
+
832
+ for name, module in self.named_children():
833
+ fn_recursive_attn_processor(name, module, processor)
834
+
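+ # Usage sketch (illustrative; assumes `unet` is an instance of this class): pass a single
+ # processor to apply it everywhere, or a dict keyed by the paths returned by `attn_processors`
+ # to target specific layers, e.g.
+ #
+ #     from diffusers.models.attention_processor import AttnProcessor2_0
+ #     unet.set_attn_processor(AttnProcessor2_0())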
835
+ def set_default_attn_processor(self):
836
+ """
837
+ Disables custom attention processors and sets the default attention implementation.
838
+ """
839
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
840
+ processor = AttnAddedKVProcessor()
841
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
842
+ processor = AttnProcessor()
843
+ else:
844
+ raise ValueError(
845
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
846
+ )
847
+
848
+ self.set_attn_processor(processor)
849
+
850
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
851
+ r"""
852
+ Enable sliced attention computation.
853
+
854
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
855
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
856
+
857
+ Args:
858
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
859
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
860
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
861
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
862
+ must be a multiple of `slice_size`.
863
+ """
864
+ sliceable_head_dims = []
865
+
866
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
867
+ if hasattr(module, "set_attention_slice"):
868
+ sliceable_head_dims.append(module.sliceable_head_dim)
869
+
870
+ for child in module.children():
871
+ fn_recursive_retrieve_sliceable_dims(child)
872
+
873
+ # retrieve number of attention layers
874
+ for module in self.children():
875
+ fn_recursive_retrieve_sliceable_dims(module)
876
+
877
+ num_sliceable_layers = len(sliceable_head_dims)
878
+
879
+ if slice_size == "auto":
880
+ # half the attention head size is usually a good trade-off between
881
+ # speed and memory
882
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
883
+ elif slice_size == "max":
884
+ # make smallest slice possible
885
+ slice_size = num_sliceable_layers * [1]
886
+
887
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
888
+
889
+ if len(slice_size) != len(sliceable_head_dims):
890
+ raise ValueError(
891
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
892
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
893
+ )
894
+
895
+ for i in range(len(slice_size)):
896
+ size = slice_size[i]
897
+ dim = sliceable_head_dims[i]
898
+ if size is not None and size > dim:
899
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
900
+
901
+ # Recursively walk through all the children.
902
+ # Any children which exposes the set_attention_slice method
903
+ # gets the message
904
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
905
+ if hasattr(module, "set_attention_slice"):
906
+ module.set_attention_slice(slice_size.pop())
907
+
908
+ for child in module.children():
909
+ fn_recursive_set_attention_slice(child, slice_size)
910
+
911
+ reversed_slice_size = list(reversed(slice_size))
912
+ for module in self.children():
913
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
914
+
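+ # Usage sketch (assumes `unet` is an instance of this class): `unet.set_attention_slice("auto")`
+ # halves each sliceable head dimension, trading a little speed for lower peak memory;
+ # `"max"` runs one slice at a time for the largest saving.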
915
+ def _set_gradient_checkpointing(self, module, value=False):
916
+ if hasattr(module, "gradient_checkpointing"):
917
+ module.gradient_checkpointing = value
918
+
919
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
920
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
921
+
922
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
923
+
924
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
925
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
926
+
927
+ Args:
928
+ s1 (`float`):
929
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
930
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
931
+ s2 (`float`):
932
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
933
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
934
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
935
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
936
+ """
937
+ for i, upsample_block in enumerate(self.up_blocks):
938
+ setattr(upsample_block, "s1", s1)
939
+ setattr(upsample_block, "s2", s2)
940
+ setattr(upsample_block, "b1", b1)
941
+ setattr(upsample_block, "b2", b2)
942
+
943
+ def disable_freeu(self):
944
+ """Disables the FreeU mechanism."""
945
+ freeu_keys = {"s1", "s2", "b1", "b2"}
946
+ for i, upsample_block in enumerate(self.up_blocks):
947
+ for k in freeu_keys:
948
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
949
+ setattr(upsample_block, k, None)
950
+
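+ # Usage sketch (assumes `unet` is an instance of this class); the SDXL values recommended in the
+ # FreeU repository are a common starting point and may need per-pipeline tuning:
+ #
+ #     unet.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
+ #     unet.disable_freeu()  # restore the vanilla skip/backbone weighting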
951
+ def fuse_qkv_projections(self):
952
+ """
953
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
954
+ are fused. For cross-attention modules, key and value projection matrices are fused.
955
+
956
+ <Tip warning={true}>
957
+
958
+ This API is 🧪 experimental.
959
+
960
+ </Tip>
961
+ """
962
+ self.original_attn_processors = None
963
+
964
+ for _, attn_processor in self.attn_processors.items():
965
+ if "Added" in str(attn_processor.__class__.__name__):
966
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
967
+
968
+ self.original_attn_processors = self.attn_processors
969
+
970
+ for module in self.modules():
971
+ if isinstance(module, Attention):
972
+ module.fuse_projections(fuse=True)
973
+
974
+ def unfuse_qkv_projections(self):
975
+ """Disables the fused QKV projection if enabled.
976
+
977
+ <Tip warning={true}>
978
+
979
+ This API is 🧪 experimental.
980
+
981
+ </Tip>
982
+
983
+ """
984
+ if self.original_attn_processors is not None:
985
+ self.set_attn_processor(self.original_attn_processors)
986
+
987
+ def unload_lora(self):
988
+ """Unloads LoRA weights."""
989
+ deprecate(
990
+ "unload_lora",
991
+ "0.28.0",
992
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters()`.",
993
+ )
994
+ for module in self.modules():
995
+ if hasattr(module, "set_lora_layer"):
996
+ module.set_lora_layer(None)
997
+
998
+ def get_time_embed(
999
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
1000
+ ) -> Optional[torch.Tensor]:
1001
+ timesteps = timestep
1002
+ if not torch.is_tensor(timesteps):
1003
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1004
+ # This would be a good case for the `match` statement (Python 3.10+)
1005
+ is_mps = sample.device.type == "mps"
1006
+ if isinstance(timestep, float):
1007
+ dtype = torch.float32 if is_mps else torch.float64
1008
+ else:
1009
+ dtype = torch.int32 if is_mps else torch.int64
1010
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1011
+ elif len(timesteps.shape) == 0:
1012
+ timesteps = timesteps[None].to(sample.device)
1013
+
1014
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1015
+ timesteps = timesteps.expand(sample.shape[0])
1016
+
1017
+ t_emb = self.time_proj(timesteps)
1018
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1019
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1020
+ # there might be better ways to encapsulate this.
1021
+ t_emb = t_emb.to(dtype=sample.dtype)
1022
+ return t_emb
1023
+
1024
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
1025
+ class_emb = None
1026
+ if self.class_embedding is not None:
1027
+ if class_labels is None:
1028
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
1029
+
1030
+ if self.config.class_embed_type == "timestep":
1031
+ class_labels = self.time_proj(class_labels)
1032
+
1033
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1034
+ # there might be better ways to encapsulate this.
1035
+ class_labels = class_labels.to(dtype=sample.dtype)
1036
+
1037
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1038
+ return class_emb
1039
+
1040
+ def get_aug_embed(
1041
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1042
+ ) -> Optional[torch.Tensor]:
1043
+ aug_emb = None
1044
+ if self.config.addition_embed_type == "text":
1045
+ aug_emb = self.add_embedding(encoder_hidden_states)
1046
+ elif self.config.addition_embed_type == "text_image":
1047
+ # Kandinsky 2.1 - style
1048
+ if "image_embeds" not in added_cond_kwargs:
1049
+ raise ValueError(
1050
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1051
+ )
1052
+
1053
+ image_embs = added_cond_kwargs.get("image_embeds")
1054
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1055
+ aug_emb = self.add_embedding(text_embs, image_embs)
1056
+ elif self.config.addition_embed_type == "text_time":
1057
+ # SDXL - style
1058
+ if "text_embeds" not in added_cond_kwargs:
1059
+ raise ValueError(
1060
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1061
+ )
1062
+ text_embeds = added_cond_kwargs.get("text_embeds")
1063
+ if "time_ids" not in added_cond_kwargs:
1064
+ raise ValueError(
1065
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1066
+ )
1067
+ time_ids = added_cond_kwargs.get("time_ids")
1068
+ time_embeds = self.add_time_proj(time_ids.flatten())
1069
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1070
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1071
+ add_embeds = add_embeds.to(emb.dtype)
1072
+ aug_emb = self.add_embedding(add_embeds)
1073
+ elif self.config.addition_embed_type == "image":
1074
+ # Kandinsky 2.2 - style
1075
+ if "image_embeds" not in added_cond_kwargs:
1076
+ raise ValueError(
1077
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1078
+ )
1079
+ image_embs = added_cond_kwargs.get("image_embeds")
1080
+ aug_emb = self.add_embedding(image_embs)
1081
+ elif self.config.addition_embed_type == "image_hint":
1082
+ # Kandinsky 2.2 - style
1083
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
1084
+ raise ValueError(
1085
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1086
+ )
1087
+ image_embs = added_cond_kwargs.get("image_embeds")
1088
+ hint = added_cond_kwargs.get("hint")
1089
+ aug_emb = self.add_embedding(image_embs, hint)
1090
+ return aug_emb
1091
+
1092
+ def process_encoder_hidden_states(
1093
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1094
+ ) -> torch.Tensor:
1095
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1096
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1097
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1098
+ # Kandinsky 2.1 - style
1099
+ if "image_embeds" not in added_cond_kwargs:
1100
+ raise ValueError(
1101
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1102
+ )
1103
+
1104
+ image_embeds = added_cond_kwargs.get("image_embeds")
1105
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1106
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1107
+ # Kandinsky 2.2 - style
1108
+ if "image_embeds" not in added_cond_kwargs:
1109
+ raise ValueError(
1110
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1111
+ )
1112
+ image_embeds = added_cond_kwargs.get("image_embeds")
1113
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1114
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1115
+ if "image_embeds" not in added_cond_kwargs:
1116
+ raise ValueError(
1117
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1118
+ )
1119
+ image_embeds = added_cond_kwargs.get("image_embeds")
1120
+ image_embeds = self.encoder_hid_proj(image_embeds)
1121
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1122
+ return encoder_hidden_states
1123
+
1124
+ def forward(
1125
+ self,
1126
+ sample: torch.FloatTensor,
1127
+ timestep: Union[torch.Tensor, float, int],
1128
+ encoder_hidden_states: torch.Tensor,
1129
+ class_labels: Optional[torch.Tensor] = None,
1130
+ timestep_cond: Optional[torch.Tensor] = None,
1131
+ attention_mask: Optional[torch.Tensor] = None,
1132
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1133
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1134
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1135
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1136
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1137
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1138
+ return_dict: bool = True,
1139
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1140
+ r"""
1141
+ The [`UNet2DConditionModel`] forward method.
1142
+
1143
+ Args:
1144
+ sample (`torch.FloatTensor`):
1145
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1146
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
1147
+ encoder_hidden_states (`torch.FloatTensor`):
1148
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1149
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1150
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1151
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1152
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1153
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1154
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1155
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1156
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1157
+ negative values to the attention scores corresponding to "discard" tokens.
1158
+ cross_attention_kwargs (`dict`, *optional*):
1159
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1160
+ `self.processor` in
1161
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1162
+ added_cond_kwargs: (`dict`, *optional*):
1163
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1164
+ are passed along to the UNet blocks.
1165
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1166
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
1167
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
1168
+ A tensor that if specified is added to the residual of the middle unet block.
1169
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1170
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1171
+ encoder_attention_mask (`torch.Tensor`):
1172
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1173
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1174
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1175
+ return_dict (`bool`, *optional*, defaults to `True`):
1176
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1177
+ tuple.
1178
+
1179
+ Returns:
1180
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1181
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1182
+ otherwise a `tuple` is returned where the first element is the sample tensor.
1183
+ """
1184
+ # By default samples have to be at least a multiple of the overall upsampling factor.
1185
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1186
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1187
+ # on the fly if necessary.
1188
+ default_overall_up_factor = 2**self.num_upsamplers
1189
+
1190
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1191
+ forward_upsample_size = False
1192
+ upsample_size = None
1193
+
1194
+ for dim in sample.shape[-2:]:
1195
+ if dim % default_overall_up_factor != 0:
1196
+ # Forward upsample size to force interpolation output size.
1197
+ forward_upsample_size = True
1198
+ break
1199
+
1200
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1201
+ # expects mask of shape:
1202
+ # [batch, key_tokens]
1203
+ # adds singleton query_tokens dimension:
1204
+ # [batch, 1, key_tokens]
1205
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1206
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1207
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1208
+ if attention_mask is not None:
1209
+ # assume that mask is expressed as:
1210
+ # (1 = keep, 0 = discard)
1211
+ # convert mask into a bias that can be added to attention scores:
1212
+ # (keep = +0, discard = -10000.0)
1213
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1214
+ attention_mask = attention_mask.unsqueeze(1)
1215
+
1216
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1217
+ if encoder_attention_mask is not None:
1218
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1219
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1220
+
1221
+ # 0. center input if necessary
1222
+ if self.config.center_input_sample:
1223
+ sample = 2 * sample - 1.0
1224
+
1225
+ # 1. time
1226
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1227
+ emb = self.time_embedding(t_emb, timestep_cond)
1228
+ aug_emb = None
1229
+
1230
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1231
+ if class_emb is not None:
1232
+ if self.config.class_embeddings_concat:
1233
+ emb = torch.cat([emb, class_emb], dim=-1)
1234
+ else:
1235
+ emb = emb + class_emb
1236
+
1237
+ aug_emb = self.get_aug_embed(
1238
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1239
+ )
1240
+ if self.config.addition_embed_type == "image_hint":
1241
+ aug_emb, hint = aug_emb
1242
+ sample = torch.cat([sample, hint], dim=1)
1243
+
1244
+ emb = emb + aug_emb if aug_emb is not None else emb
1245
+
1246
+ if self.time_embed_act is not None:
1247
+ emb = self.time_embed_act(emb)
1248
+
1249
+ encoder_hidden_states = self.process_encoder_hidden_states(
1250
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1251
+ )
1252
+
1253
+ # 2. pre-process
1254
+ sample = self.conv_in(sample)
1255
+
1256
+ # 2.5 GLIGEN position net
1257
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1258
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1259
+ gligen_args = cross_attention_kwargs.pop("gligen")
1260
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1261
+
1262
+ # 3. down
1263
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1264
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1265
+ if cross_attention_kwargs is not None:
1266
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1267
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1268
+ else:
1269
+ lora_scale = 1.0
1270
+
1271
+ if USE_PEFT_BACKEND:
1272
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1273
+ scale_lora_layers(self, lora_scale)
1274
+
1275
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1276
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1277
+ is_adapter = down_intrablock_additional_residuals is not None
1278
+ # maintain backward compatibility for legacy usage, where
1279
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1280
+ # but can only use one or the other
1281
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1282
+ deprecate(
1283
+ "T2I should not use down_block_additional_residuals",
1284
+ "1.3.0",
1285
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1286
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1287
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1288
+ standard_warn=False,
1289
+ )
1290
+ down_intrablock_additional_residuals = down_block_additional_residuals
1291
+ is_adapter = True
1292
+
1293
+ down_block_res_samples = (sample,)
1294
+ for downsample_block in self.down_blocks:
1295
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1296
+ # For t2i-adapter CrossAttnDownBlock2D
1297
+ additional_residuals = {}
1298
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1299
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1300
+
1301
+ sample, res_samples = downsample_block(
1302
+ hidden_states=sample,
1303
+ temb=emb,
1304
+ encoder_hidden_states=encoder_hidden_states,
1305
+ attention_mask=attention_mask,
1306
+ cross_attention_kwargs=cross_attention_kwargs,
1307
+ encoder_attention_mask=encoder_attention_mask,
1308
+ **additional_residuals,
1309
+ )
1310
+ else:
1311
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
1312
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1313
+ sample += down_intrablock_additional_residuals.pop(0)
1314
+
1315
+ down_block_res_samples += res_samples
1316
+
1317
+ if is_controlnet:
1318
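+ # Unlike the stock UNet, the ControlNet/aggregator residuals are not summed into the skip
+ # features here; each residual is kept paired with its skip feature so the ZeroSFT-aware up
+ # blocks (see unet_2d_ZeroSFT_blocks.py) can fuse each pair downstream.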
+ new_down_block_res_samples = ()
1319
+
1320
+ for down_block_additional_residual, down_block_res_sample in zip(
1321
+ down_block_additional_residuals, down_block_res_samples
1322
+ ):
1323
+ down_block_res_sample_tuple = (down_block_additional_residual, down_block_res_sample)
1324
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample_tuple,)
1325
+
1326
+ down_block_res_samples = new_down_block_res_samples
1327
+
1328
+ # 4. mid
1329
+ if self.mid_block is not None:
1330
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1331
+ sample = self.mid_block(
1332
+ sample,
1333
+ emb,
1334
+ encoder_hidden_states=encoder_hidden_states,
1335
+ attention_mask=attention_mask,
1336
+ cross_attention_kwargs=cross_attention_kwargs,
1337
+ encoder_attention_mask=encoder_attention_mask,
1338
+ )
1339
+ else:
1340
+ sample = self.mid_block(sample, emb)
1341
+
1342
+ # To support T2I-Adapter-XL
1343
+ if (
1344
+ is_adapter
1345
+ and len(down_intrablock_additional_residuals) > 0
1346
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1347
+ ):
1348
+ sample += down_intrablock_additional_residuals.pop(0)
1349
+
1350
+ if is_controlnet:
1351
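+ # Fuse the mid-block ControlNet/aggregator residual through ZeroSFT rather than plain addition.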
+ sample = self.mid_zero_SFT((mid_block_additional_residual, sample))
1352
+
1353
+ # 5. up
1354
+ for i, upsample_block in enumerate(self.up_blocks):
1355
+ is_final_block = i == len(self.up_blocks) - 1
1356
+
1357
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1358
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1359
+
1360
+ # if we have not reached the final block and need to forward the
1361
+ # upsample size, we do it here
1362
+ if not is_final_block and forward_upsample_size:
1363
+ upsample_size = down_block_res_samples[-1].shape[2:]
1364
+
1365
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1366
+ sample = upsample_block(
1367
+ hidden_states=sample,
1368
+ temb=emb,
1369
+ res_hidden_states_tuple=res_samples,
1370
+ encoder_hidden_states=encoder_hidden_states,
1371
+ cross_attention_kwargs=cross_attention_kwargs,
1372
+ upsample_size=upsample_size,
1373
+ attention_mask=attention_mask,
1374
+ encoder_attention_mask=encoder_attention_mask,
1375
+ )
1376
+ else:
1377
+ sample = upsample_block(
1378
+ hidden_states=sample,
1379
+ temb=emb,
1380
+ res_hidden_states_tuple=res_samples,
1381
+ upsample_size=upsample_size,
1382
+ )
1383
+
1384
+ # 6. post-process
1385
+ if self.conv_norm_out:
1386
+ sample = self.conv_norm_out(sample)
1387
+ sample = self.conv_act(sample)
1388
+ sample = self.conv_out(sample)
1389
+
1390
+ if USE_PEFT_BACKEND:
1391
+ # remove `lora_scale` from each PEFT layer
1392
+ unscale_lora_layers(self, lora_scale)
1393
+
1394
+ if not return_dict:
1395
+ return (sample,)
1396
+
1397
+ return UNet2DConditionOutput(sample=sample)
module/unet/unet_2d_ZeroSFT_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
pipelines/sdxl_instantir.py ADDED
@@ -0,0 +1,1740 @@
1
+ # Copyright 2024 The InstantX Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import (
24
+ CLIPImageProcessor,
25
+ CLIPTextModel,
26
+ CLIPTextModelWithProjection,
27
+ CLIPTokenizer,
28
+ CLIPVisionModelWithProjection,
29
+ )
30
+
31
+ from diffusers.utils.import_utils import is_invisible_watermark_available
32
+
33
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
34
+ from diffusers.loaders import (
35
+ FromSingleFileMixin,
36
+ IPAdapterMixin,
37
+ StableDiffusionXLLoraLoaderMixin,
38
+ TextualInversionLoaderMixin,
39
+ )
40
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
41
+ from diffusers.models.attention_processor import (
42
+ AttnProcessor2_0,
43
+ LoRAAttnProcessor2_0,
44
+ LoRAXFormersAttnProcessor,
45
+ XFormersAttnProcessor,
46
+ )
47
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
48
+ from diffusers.schedulers import KarrasDiffusionSchedulers, LCMScheduler
49
+ from diffusers.utils import (
50
+ USE_PEFT_BACKEND,
51
+ deprecate,
52
+ logging,
53
+ replace_example_docstring,
54
+ scale_lora_layers,
55
+ unscale_lora_layers,
56
+ convert_unet_state_dict_to_peft
57
+ )
58
+ from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
59
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
60
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
61
+
62
+
63
+ if is_invisible_watermark_available():
64
+ from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
65
+
66
+ from peft import LoraConfig, set_peft_model_state_dict
67
+ from module.aggregator import Aggregator
68
+
69
+
70
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
71
+
72
+
73
+ EXAMPLE_DOC_STRING = """
74
+ Examples:
75
+ ```py
76
+ >>> # !pip install diffusers pillow transformers accelerate
77
+ >>> import torch
78
+ >>> from PIL import Image
79
+
80
+ >>> from diffusers import DDPMScheduler
81
+ >>> from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
82
+
83
+ >>> from module.ip_adapter.utils import load_adapter_to_pipe
84
+ >>> from pipelines.sdxl_instantir import InstantIRPipeline
85
+
86
+ >>> # download models under ./models
87
+ >>> dcp_adapter = f'./models/adapter.pt'
88
+ >>> previewer_lora_path = f'./models'
89
+ >>> instantir_path = f'./models/aggregator.pt'
90
+
91
+ >>> # load pretrained models
92
+ >>> pipe = InstantIRPipeline.from_pretrained(
93
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
94
+ ... )
95
+ >>> # load adapter
96
+ >>> load_adapter_to_pipe(
97
+ ... pipe,
98
+ ... dcp_adapter,
99
+ ... image_encoder_or_path = 'facebook/dinov2-large',
100
+ ... )
101
+ >>> # load previewer lora
102
+ >>> pipe.prepare_previewers(previewer_lora_path)
103
+ >>> pipe.scheduler = DDPMScheduler.from_pretrained('stabilityai/stable-diffusion-xl-base-1.0', subfolder="scheduler")
104
+ >>> lcm_scheduler = LCMSingleStepScheduler.from_config(pipe.scheduler.config)
105
+
106
+ >>> # load aggregator weights
107
+ >>> pretrained_state_dict = torch.load(instantir_path)
108
+ >>> pipe.aggregator.load_state_dict(pretrained_state_dict)
109
+
110
+ >>> # send to GPU and fp16
111
+ >>> pipe.to(device="cuda", dtype=torch.float16)
112
+ >>> pipe.aggregator.to(device="cuda", dtype=torch.float16)
113
+ >>> pipe.enable_model_cpu_offload()
114
+
115
+ >>> # load a broken image
116
+ >>> low_quality_image = Image.open('path/to/your-image').convert("RGB")
117
+
118
+ >>> # restoration
119
+ >>> image = pipe(
120
+ ... image=low_quality_image,
121
+ ... previewer_scheduler=lcm_scheduler,
122
+ ... ).images[0]
123
+ ```
124
+ """
125
+
126
+ LCM_LORA_MODULES = [
127
+ "to_q",
128
+ "to_k",
129
+ "to_v",
130
+ "to_out.0",
131
+ "proj_in",
132
+ "proj_out",
133
+ "ff.net.0.proj",
134
+ "ff.net.2",
135
+ "conv1",
136
+ "conv2",
137
+ "conv_shortcut",
138
+ "downsamplers.0.conv",
139
+ "upsamplers.0.conv",
140
+ "time_emb_proj",
141
+ ]
142
+ PREVIEWER_LORA_MODULES = [
143
+ "to_q",
144
+ "to_kv",
145
+ "0.to_out",
146
+ "attn1.to_k",
147
+ "attn1.to_v",
148
+ "to_k_ip",
149
+ "to_v_ip",
150
+ "ln_k_ip.linear",
151
+ "ln_v_ip.linear",
152
+ "to_out.0",
153
+ "proj_in",
154
+ "proj_out",
155
+ "ff.net.0.proj",
156
+ "ff.net.2",
157
+ "conv1",
158
+ "conv2",
159
+ "conv_shortcut",
160
+ "downsamplers.0.conv",
161
+ "upsamplers.0.conv",
162
+ "time_emb_proj",
163
+ ]
164
+
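+ # Editor's note: a minimal sketch (values illustrative; `lora_alpha` is normally read from the
+ # checkpoint) of how these module-name lists are consumed as PEFT `target_modules` in
+ # `prepare_previewers` below:
+ #
+ #     from peft import LoraConfig
+ #     lora_config = LoraConfig(r=64, target_modules=PREVIEWER_LORA_MODULES, lora_alpha=64, lora_dropout=0.0)
+ #     unet.add_adapter(lora_config, adapter_name="previewer")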
165
+
166
+ def remove_attn2(model):
167
+ def recursive_find_module(name, module):
168
+ if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return
169
+ elif "resnets" in name: return
170
+ if hasattr(module, "attn2"):
171
+ setattr(module, "attn2", None)
172
+ setattr(module, "norm2", None)
173
+ return
174
+ for sub_name, sub_module in module.named_children():
175
+ recursive_find_module(f"{name}.{sub_name}", sub_module)
176
+
177
+ for name, module in model.named_children():
178
+ recursive_find_module(name, module)
179
+
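+ # Editor's note: a usage sketch for `remove_attn2` (checkpoint id matches the example docstring above).
+ # It drops the text cross-attention (`attn2`) and its norm from every transformer block in the
+ # down/mid/up stages, which is how the Aggregator copy of the UNet is built in `__init__`:
+ #
+ #     from diffusers import UNet2DConditionModel
+ #     from module.aggregator import Aggregator
+ #     unet = UNet2DConditionModel.from_pretrained(
+ #         "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
+ #     )
+ #     aggregator = Aggregator.from_unet(unet)
+ #     remove_attn2(aggregator)  # affected `attn2`/`norm2` attributes are now set to None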
180
+
181
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
182
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
183
+ """
184
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
185
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
186
+ """
187
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
188
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
189
+ # rescale the results from guidance (fixes overexposure)
190
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
191
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
192
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
193
+ return noise_cfg
194
+
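+ # Editor's note: the function above implements, with phi = `guidance_rescale`,
+ #     x_rescaled = x_cfg * std(x_text) / std(x_cfg)
+ #     x_out      = phi * x_rescaled + (1 - phi) * x_cfg
+ # A quick shape-level sanity check (illustrative only):
+ #     noise_cfg, noise_text = torch.randn(2, 4, 128, 128), torch.randn(2, 4, 128, 128)
+ #     assert rescale_noise_cfg(noise_cfg, noise_text, guidance_rescale=0.7).shape == noise_cfg.shape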
195
+
196
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
197
+ def retrieve_timesteps(
198
+ scheduler,
199
+ num_inference_steps: Optional[int] = None,
200
+ device: Optional[Union[str, torch.device]] = None,
201
+ timesteps: Optional[List[int]] = None,
202
+ **kwargs,
203
+ ):
204
+ """
205
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
206
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
207
+
208
+ Args:
209
+ scheduler (`SchedulerMixin`):
210
+ The scheduler to get timesteps from.
211
+ num_inference_steps (`int`):
212
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
213
+ must be `None`.
214
+ device (`str` or `torch.device`, *optional*):
215
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
216
+ timesteps (`List[int]`, *optional*):
217
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
218
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
219
+ must be `None`.
220
+
221
+ Returns:
222
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
223
+ second element is the number of inference steps.
224
+ """
225
+ if timesteps is not None:
226
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
227
+ if not accepts_timesteps:
228
+ raise ValueError(
229
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
230
+ f" timestep schedules. Please check whether you are using the correct scheduler."
231
+ )
232
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
233
+ timesteps = scheduler.timesteps
234
+ num_inference_steps = len(timesteps)
235
+ else:
236
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
237
+ timesteps = scheduler.timesteps
238
+ return timesteps, num_inference_steps
239
+
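+ # Editor's note: a hedged usage sketch (scheduler id matches the example docstring above):
+ #     from diffusers import DDPMScheduler
+ #     scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
+ #     timesteps, n = retrieve_timesteps(scheduler, num_inference_steps=30, device="cuda")
+ #     # passing an explicit `timesteps` list only works if the scheduler's `set_timesteps` accepts it,
+ #     # otherwise the ValueError above is raised.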
240
+
241
+ class InstantIRPipeline(
242
+ DiffusionPipeline,
243
+ StableDiffusionMixin,
244
+ TextualInversionLoaderMixin,
245
+ StableDiffusionXLLoraLoaderMixin,
246
+ IPAdapterMixin,
247
+ FromSingleFileMixin,
248
+ ):
249
+ r"""
250
+ Pipeline for blind image restoration (InstantIR) built on Stable Diffusion XL, with degraded-image guidance from an Aggregator module.
251
+
252
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
253
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
254
+
255
+ The pipeline also inherits the following loading methods:
256
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
257
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
258
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
259
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
260
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
261
+
262
+ Args:
263
+ vae ([`AutoencoderKL`]):
264
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
265
+ text_encoder ([`~transformers.CLIPTextModel`]):
266
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
267
+ text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
268
+ Second frozen text-encoder
269
+ ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
270
+ tokenizer ([`~transformers.CLIPTokenizer`]):
271
+ A `CLIPTokenizer` to tokenize text.
272
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]):
273
+ A `CLIPTokenizer` to tokenize text.
274
+ unet ([`UNet2DConditionModel`]):
275
+ A `UNet2DConditionModel` to denoise the encoded image latents.
276
+ aggregator ([`Aggregator`]):
277
+ Provides low-quality image conditioning to the `unet` during the denoising process. If not passed,
278
+ an `Aggregator` is created from the `unet` (with its text cross-attention removed) when the
279
+ pipeline is instantiated.
280
+ scheduler ([`SchedulerMixin`]):
281
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
282
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
283
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
284
+ Whether the negative prompt embeddings should always be set to 0. Also see the config of
285
+ `stabilityai/stable-diffusion-xl-base-1.0`.
286
+ add_watermarker (`bool`, *optional*):
287
+ Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
288
+ watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
289
+ watermarker is used.
290
+ """
291
+
292
+ # leave the aggregator out of the offload sequence on purpose because it runs alongside the unet at every denoising step
293
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
294
+ _optional_components = [
295
+ "tokenizer",
296
+ "tokenizer_2",
297
+ "text_encoder",
298
+ "text_encoder_2",
299
+ "feature_extractor",
300
+ "image_encoder",
301
+ ]
302
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
303
+
304
+ def __init__(
305
+ self,
306
+ vae: AutoencoderKL,
307
+ text_encoder: CLIPTextModel,
308
+ text_encoder_2: CLIPTextModelWithProjection,
309
+ tokenizer: CLIPTokenizer,
310
+ tokenizer_2: CLIPTokenizer,
311
+ unet: UNet2DConditionModel,
312
+ scheduler: KarrasDiffusionSchedulers,
313
+ aggregator: Aggregator = None,
314
+ force_zeros_for_empty_prompt: bool = True,
315
+ add_watermarker: Optional[bool] = None,
316
+ feature_extractor: CLIPImageProcessor = None,
317
+ image_encoder: CLIPVisionModelWithProjection = None,
318
+ ):
319
+ super().__init__()
320
+
321
+ if aggregator is None:
322
+ aggregator = Aggregator.from_unet(unet)
323
+ remove_attn2(aggregator)
324
+
325
+ self.register_modules(
326
+ vae=vae,
327
+ text_encoder=text_encoder,
328
+ text_encoder_2=text_encoder_2,
329
+ tokenizer=tokenizer,
330
+ tokenizer_2=tokenizer_2,
331
+ unet=unet,
332
+ aggregator=aggregator,
333
+ scheduler=scheduler,
334
+ feature_extractor=feature_extractor,
335
+ image_encoder=image_encoder,
336
+ )
337
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
338
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
339
+ self.control_image_processor = VaeImageProcessor(
340
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=True
341
+ )
342
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
343
+
344
+ if add_watermarker:
345
+ self.watermark = StableDiffusionXLWatermarker()
346
+ else:
347
+ self.watermark = None
348
+
349
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
350
+
351
+ def prepare_previewers(self, previewer_lora_path: str, use_lcm=False):
352
+ if use_lcm:
353
+ lora_state_dict, alpha_dict = self.lora_state_dict(
354
+ previewer_lora_path,
355
+ )
356
+ else:
357
+ lora_state_dict, alpha_dict = self.lora_state_dict(
358
+ previewer_lora_path,
359
+ weight_name="previewer_lora_weights.bin"
360
+ )
361
+ unet_state_dict = {
362
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
363
+ }
364
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
365
+ lora_state_dict = dict()
366
+ for k, v in unet_state_dict.items():
367
+ # IP-Adapter LoRA weights live on the attention processor, so remap those keys before loading
368
+ if "ip" in k:
369
+ k = k.replace("attn2", "attn2.processor")
370
+ lora_state_dict[k] = v
372
+ if alpha_dict:
373
+ lora_alpha = next(iter(alpha_dict.values()))
374
+ else:
375
+ lora_alpha = 1
376
+ logger.info(f"use lora alpha {lora_alpha}")
377
+ lora_config = LoraConfig(
378
+ r=64,
379
+ target_modules=LCM_LORA_MODULES if use_lcm else PREVIEWER_LORA_MODULES,
380
+ lora_alpha=lora_alpha,
381
+ lora_dropout=0.0,
382
+ )
383
+
384
+ adapter_name = "lcm" if use_lcm else "previewer"
385
+ self.unet.add_adapter(lora_config, adapter_name)
386
+ incompatible_keys = set_peft_model_state_dict(self.unet, lora_state_dict, adapter_name=adapter_name)
387
+ if incompatible_keys is not None:
388
+ # check only for unexpected keys
389
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
390
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
391
+ if unexpected_keys:
392
+ raise ValueError(
393
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
394
+ f" {unexpected_keys}. "
395
+ )
396
+ self.unet.disable_adapters()
397
+
398
+ return lora_alpha
399
+
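+ # Editor's note (sketch): the previewer LoRA is registered but left disabled here; the caller is
+ # expected to toggle it around the preview pass, roughly:
+ #     pipe.prepare_previewers("./models")   # loads the "previewer" adapter, then disables adapters
+ #     pipe.unet.enable_adapters()           # switch the UNet to previewer weights for the one-step preview
+ #     ...                                   # previewer prediction with LCMSingleStepScheduler
+ #     pipe.unet.disable_adapters()          # restore the base SDXL weights for the main denoising step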
400
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
401
+ def encode_prompt(
402
+ self,
403
+ prompt: str,
404
+ prompt_2: Optional[str] = None,
405
+ device: Optional[torch.device] = None,
406
+ num_images_per_prompt: int = 1,
407
+ do_classifier_free_guidance: bool = True,
408
+ negative_prompt: Optional[str] = None,
409
+ negative_prompt_2: Optional[str] = None,
410
+ prompt_embeds: Optional[torch.FloatTensor] = None,
411
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
412
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
413
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
414
+ lora_scale: Optional[float] = None,
415
+ clip_skip: Optional[int] = None,
416
+ ):
417
+ r"""
418
+ Encodes the prompt into text encoder hidden states.
419
+
420
+ Args:
421
+ prompt (`str` or `List[str]`, *optional*):
422
+ prompt to be encoded
423
+ prompt_2 (`str` or `List[str]`, *optional*):
424
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
425
+ used in both text-encoders
426
+ device: (`torch.device`):
427
+ torch device
428
+ num_images_per_prompt (`int`):
429
+ number of images that should be generated per prompt
430
+ do_classifier_free_guidance (`bool`):
431
+ whether to use classifier free guidance or not
432
+ negative_prompt (`str` or `List[str]`, *optional*):
433
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
434
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
435
+ less than `1`).
436
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
437
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
438
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
439
+ prompt_embeds (`torch.FloatTensor`, *optional*):
440
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
441
+ provided, text embeddings will be generated from `prompt` input argument.
442
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
443
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
444
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
445
+ argument.
446
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
447
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
448
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
449
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
450
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
451
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
452
+ input argument.
453
+ lora_scale (`float`, *optional*):
454
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
455
+ clip_skip (`int`, *optional*):
456
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
457
+ the output of the pre-final layer will be used for computing the prompt embeddings.
458
+ """
459
+ device = device or self._execution_device
460
+
461
+ # set lora scale so that monkey patched LoRA
462
+ # function of text encoder can correctly access it
463
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
464
+ self._lora_scale = lora_scale
465
+
466
+ # dynamically adjust the LoRA scale
467
+ if self.text_encoder is not None:
468
+ if not USE_PEFT_BACKEND:
469
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
470
+ else:
471
+ scale_lora_layers(self.text_encoder, lora_scale)
472
+
473
+ if self.text_encoder_2 is not None:
474
+ if not USE_PEFT_BACKEND:
475
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
476
+ else:
477
+ scale_lora_layers(self.text_encoder_2, lora_scale)
478
+
479
+ prompt = [prompt] if isinstance(prompt, str) else prompt
480
+
481
+ if prompt is not None:
482
+ batch_size = len(prompt)
483
+ else:
484
+ batch_size = prompt_embeds.shape[0]
485
+
486
+ # Define tokenizers and text encoders
487
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
488
+ text_encoders = (
489
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
490
+ )
491
+
492
+ if prompt_embeds is None:
493
+ prompt_2 = prompt_2 or prompt
494
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
495
+
496
+ # textual inversion: process multi-vector tokens if necessary
497
+ prompt_embeds_list = []
498
+ prompts = [prompt, prompt_2]
499
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
500
+ if isinstance(self, TextualInversionLoaderMixin):
501
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
502
+
503
+ text_inputs = tokenizer(
504
+ prompt,
505
+ padding="max_length",
506
+ max_length=tokenizer.model_max_length,
507
+ truncation=True,
508
+ return_tensors="pt",
509
+ )
510
+
511
+ text_input_ids = text_inputs.input_ids
512
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
513
+
514
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
515
+ text_input_ids, untruncated_ids
516
+ ):
517
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
518
+ logger.warning(
519
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
520
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
521
+ )
522
+
523
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
524
+
525
+ # We are always interested only in the pooled output of the final text encoder
526
+ pooled_prompt_embeds = prompt_embeds[0]
527
+ if clip_skip is None:
528
+ prompt_embeds = prompt_embeds.hidden_states[-2]
529
+ else:
530
+ # "2" because SDXL always indexes from the penultimate layer.
531
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
532
+
533
+ prompt_embeds_list.append(prompt_embeds)
534
+
535
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
536
+
537
+ # get unconditional embeddings for classifier free guidance
538
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
539
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
540
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
541
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
542
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
543
+ negative_prompt = negative_prompt or ""
544
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
545
+
546
+ # normalize str to list
547
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
548
+ negative_prompt_2 = (
549
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
550
+ )
551
+
552
+ uncond_tokens: List[str]
553
+ if prompt is not None and type(prompt) is not type(negative_prompt):
554
+ raise TypeError(
555
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
556
+ f" {type(prompt)}."
557
+ )
558
+ elif batch_size != len(negative_prompt):
559
+ raise ValueError(
560
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
561
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
562
+ " the batch size of `prompt`."
563
+ )
564
+ else:
565
+ uncond_tokens = [negative_prompt, negative_prompt_2]
566
+
567
+ negative_prompt_embeds_list = []
568
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
569
+ if isinstance(self, TextualInversionLoaderMixin):
570
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
571
+
572
+ max_length = prompt_embeds.shape[1]
573
+ uncond_input = tokenizer(
574
+ negative_prompt,
575
+ padding="max_length",
576
+ max_length=max_length,
577
+ truncation=True,
578
+ return_tensors="pt",
579
+ )
580
+
581
+ negative_prompt_embeds = text_encoder(
582
+ uncond_input.input_ids.to(device),
583
+ output_hidden_states=True,
584
+ )
585
+ # We are always interested only in the pooled output of the final text encoder
586
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
587
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
588
+
589
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
590
+
591
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
592
+
593
+ if self.text_encoder_2 is not None:
594
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
595
+ else:
596
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
597
+
598
+ bs_embed, seq_len, _ = prompt_embeds.shape
599
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
600
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
601
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
602
+
603
+ if do_classifier_free_guidance:
604
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
605
+ seq_len = negative_prompt_embeds.shape[1]
606
+
607
+ if self.text_encoder_2 is not None:
608
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
609
+ else:
610
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
611
+
612
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
613
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
614
+
615
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
616
+ bs_embed * num_images_per_prompt, -1
617
+ )
618
+ if do_classifier_free_guidance:
619
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
620
+ bs_embed * num_images_per_prompt, -1
621
+ )
622
+
623
+ if self.text_encoder is not None:
624
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
625
+ # Retrieve the original scale by scaling back the LoRA layers
626
+ unscale_lora_layers(self.text_encoder, lora_scale)
627
+
628
+ if self.text_encoder_2 is not None:
629
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
630
+ # Retrieve the original scale by scaling back the LoRA layers
631
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
632
+
633
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
634
+
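+ # Editor's note: a minimal call sketch; the shapes in the comment assume the stock SDXL text encoders:
+ #     pe, npe, ppe, nppe = pipe.encode_prompt(
+ #         prompt="a photo", device=pipe.device, num_images_per_prompt=1, do_classifier_free_guidance=True
+ #     )
+ #     # pe/npe: (1, 77, 2048) concatenated hidden states, ppe/nppe: (1, 1280) pooled embeddings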
635
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
636
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
637
+ dtype = next(self.image_encoder.parameters()).dtype
638
+
639
+ if not isinstance(image, torch.Tensor):
640
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
641
+
642
+ image = image.to(device=device, dtype=dtype)
643
+ if output_hidden_states:
644
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
645
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
646
+ uncond_image_enc_hidden_states = self.image_encoder(
647
+ torch.zeros_like(image), output_hidden_states=True
648
+ ).hidden_states[-2]
649
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
650
+ num_images_per_prompt, dim=0
651
+ )
652
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
653
+ else:
654
+ if isinstance(self.image_encoder, CLIPVisionModelWithProjection):
655
+ # CLIP image encoder.
656
+ image_embeds = self.image_encoder(image).image_embeds
657
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
658
+ uncond_image_embeds = torch.zeros_like(image_embeds)
659
+ else:
660
+ # DINO image encoder.
661
+ image_embeds = self.image_encoder(image).last_hidden_state
662
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
663
+ uncond_image_embeds = self.image_encoder(
664
+ torch.zeros_like(image)
665
+ ).last_hidden_state
666
+ uncond_image_embeds = uncond_image_embeds.repeat_interleave(
667
+ num_images_per_prompt, dim=0
668
+ )
669
+
670
+ return image_embeds, uncond_image_embeds
671
+
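+ # Editor's note (sketch): with the DINO image encoder used by InstantIR, the non-hidden-state branch
+ # returns patch-token features rather than a pooled vector, e.g. (shapes are approximate and assume a
+ # 224x224 crop from the feature extractor):
+ #     from PIL import Image
+ #     emb, uncond_emb = pipe.encode_image(Image.open("lq.png").convert("RGB"), "cuda", 1)
+ #     # for facebook/dinov2-large this is roughly (1, 257, 1024): one CLS token plus 16x16 patches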
672
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
673
+ def prepare_ip_adapter_image_embeds(
674
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
675
+ ):
676
+ if ip_adapter_image_embeds is None:
677
+ if not isinstance(ip_adapter_image, list):
678
+ ip_adapter_image = [ip_adapter_image]
679
+
680
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
681
+ if isinstance(ip_adapter_image[0], list):
682
+ raise ValueError(
683
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
684
+ )
685
+ else:
686
+ logger.warning(
687
+ f"Got {len(ip_adapter_image)} images for {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
688
+ " By default, these images will be sent to each IP-Adapter. If this is not your use-case, please specify `ip_adapter_image` as a list of image-list, with"
689
+ f" length equals to the number of IP-Adapters."
690
+ )
691
+ ip_adapter_image = [ip_adapter_image] * len(self.unet.encoder_hid_proj.image_projection_layers)
692
+
693
+ image_embeds = []
694
+ for single_ip_adapter_image, image_proj_layer in zip(
695
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
696
+ ):
697
+ output_hidden_state = isinstance(self.image_encoder, CLIPVisionModelWithProjection) and not isinstance(image_proj_layer, ImageProjection)
698
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
699
+ single_ip_adapter_image, device, 1, output_hidden_state
700
+ )
701
+ single_image_embeds = torch.stack([single_image_embeds] * (num_images_per_prompt//single_image_embeds.shape[0]), dim=0)
702
+ single_negative_image_embeds = torch.stack(
703
+ [single_negative_image_embeds] * (num_images_per_prompt//single_negative_image_embeds.shape[0]), dim=0
704
+ )
705
+
706
+ if do_classifier_free_guidance:
707
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
708
+ single_image_embeds = single_image_embeds.to(device)
709
+
710
+ image_embeds.append(single_image_embeds)
711
+ else:
712
+ repeat_dims = [1]
713
+ image_embeds = []
714
+ for single_image_embeds in ip_adapter_image_embeds:
715
+ if do_classifier_free_guidance:
716
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
717
+ single_image_embeds = single_image_embeds.repeat(
718
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
719
+ )
720
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
721
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
722
+ )
723
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
724
+ else:
725
+ single_image_embeds = single_image_embeds.repeat(
726
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
727
+ )
728
+ image_embeds.append(single_image_embeds)
729
+
730
+ return image_embeds
731
+
732
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
733
+ def prepare_extra_step_kwargs(self, generator, eta):
734
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
735
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
736
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
737
+ # and should be between [0, 1]
738
+
739
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
740
+ extra_step_kwargs = {}
741
+ if accepts_eta:
742
+ extra_step_kwargs["eta"] = eta
743
+
744
+ # check if the scheduler accepts generator
745
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
746
+ if accepts_generator:
747
+ extra_step_kwargs["generator"] = generator
748
+ return extra_step_kwargs
749
+
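+ # Editor's note (illustrative): with `DDIMScheduler` both keys survive, i.e. {"eta": eta, "generator": generator};
+ # `DDPMScheduler.step` accepts no `eta`, so only the generator is forwarded in that case.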
750
+ def check_inputs(
751
+ self,
752
+ prompt,
753
+ prompt_2,
754
+ image,
755
+ callback_steps,
756
+ negative_prompt=None,
757
+ negative_prompt_2=None,
758
+ prompt_embeds=None,
759
+ negative_prompt_embeds=None,
760
+ pooled_prompt_embeds=None,
761
+ ip_adapter_image=None,
762
+ ip_adapter_image_embeds=None,
763
+ negative_pooled_prompt_embeds=None,
764
+ controlnet_conditioning_scale=1.0,
765
+ control_guidance_start=0.0,
766
+ control_guidance_end=1.0,
767
+ callback_on_step_end_tensor_inputs=None,
768
+ ):
769
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
770
+ raise ValueError(
771
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
772
+ f" {type(callback_steps)}."
773
+ )
774
+
775
+ if callback_on_step_end_tensor_inputs is not None and not all(
776
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
777
+ ):
778
+ raise ValueError(
779
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
780
+ )
781
+
782
+ if prompt is not None and prompt_embeds is not None:
783
+ raise ValueError(
784
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
785
+ " only forward one of the two."
786
+ )
787
+ elif prompt_2 is not None and prompt_embeds is not None:
788
+ raise ValueError(
789
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
790
+ " only forward one of the two."
791
+ )
792
+ elif prompt is None and prompt_embeds is None:
793
+ raise ValueError(
794
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
795
+ )
796
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
797
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
798
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
799
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
800
+
801
+ if negative_prompt is not None and negative_prompt_embeds is not None:
802
+ raise ValueError(
803
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
804
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
805
+ )
806
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
807
+ raise ValueError(
808
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
809
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
810
+ )
811
+
812
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
813
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
814
+ raise ValueError(
815
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
816
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
817
+ f" {negative_prompt_embeds.shape}."
818
+ )
819
+
820
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
821
+ raise ValueError(
822
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
823
+ )
824
+
825
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
826
+ raise ValueError(
827
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
828
+ )
829
+
830
+ # Check `image`
831
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
832
+ self.aggregator, torch._dynamo.eval_frame.OptimizedModule
833
+ )
834
+ if (
835
+ isinstance(self.aggregator, Aggregator)
836
+ or is_compiled
837
+ and isinstance(self.aggregator._orig_mod, Aggregator)
838
+ ):
839
+ self.check_image(image, prompt, prompt_embeds)
840
+ else:
841
+ assert False, f"`self.aggregator` must be an `Aggregator`, got {type(self.aggregator)}."
842
+
843
+ if control_guidance_start >= control_guidance_end:
844
+ raise ValueError(
845
+ f"control guidance start: {control_guidance_start} cannot be larger or equal to control guidance end: {control_guidance_end}."
846
+ )
847
+ if control_guidance_start < 0.0:
848
+ raise ValueError(f"control guidance start: {control_guidance_start} can't be smaller than 0.")
849
+ if control_guidance_end > 1.0:
850
+ raise ValueError(f"control guidance end: {control_guidance_end} can't be larger than 1.0.")
851
+
852
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
853
+ raise ValueError(
854
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
855
+ )
856
+
857
+ if ip_adapter_image_embeds is not None:
858
+ if not isinstance(ip_adapter_image_embeds, list):
859
+ raise ValueError(
860
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
861
+ )
862
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
863
+ raise ValueError(
864
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
865
+ )
866
+
867
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
868
+ def check_image(self, image, prompt, prompt_embeds):
869
+ image_is_pil = isinstance(image, PIL.Image.Image)
870
+ image_is_tensor = isinstance(image, torch.Tensor)
871
+ image_is_np = isinstance(image, np.ndarray)
872
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
873
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
874
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
875
+
876
+ if (
877
+ not image_is_pil
878
+ and not image_is_tensor
879
+ and not image_is_np
880
+ and not image_is_pil_list
881
+ and not image_is_tensor_list
882
+ and not image_is_np_list
883
+ ):
884
+ raise TypeError(
885
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
886
+ )
887
+
888
+ if image_is_pil:
889
+ image_batch_size = 1
890
+ else:
891
+ image_batch_size = len(image)
892
+
893
+ if prompt is not None and isinstance(prompt, str):
894
+ prompt_batch_size = 1
895
+ elif prompt is not None and isinstance(prompt, list):
896
+ prompt_batch_size = len(prompt)
897
+ elif prompt_embeds is not None:
898
+ prompt_batch_size = prompt_embeds.shape[0]
899
+
900
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
901
+ raise ValueError(
902
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
903
+ )
904
+
905
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
906
+ def prepare_image(
907
+ self,
908
+ image,
909
+ width,
910
+ height,
911
+ batch_size,
912
+ num_images_per_prompt,
913
+ device,
914
+ dtype,
915
+ do_classifier_free_guidance=False,
916
+ ):
917
+ image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
918
+ image_batch_size = image.shape[0]
919
+
920
+ if image_batch_size == 1:
921
+ repeat_by = batch_size
922
+ else:
923
+ # image batch size is the same as prompt batch size
924
+ repeat_by = num_images_per_prompt
925
+
926
+ image = image.repeat_interleave(repeat_by, dim=0)
927
+
928
+ image = image.to(device=device, dtype=dtype)
929
+
930
+ return image
931
+
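+ # Editor's note (sketch): for InstantIR the "control" image prepared here is the low-quality photo itself,
+ # e.g. (argument values illustrative):
+ #     cond_image = pipe.prepare_image(low_quality_image, width=1024, height=1024, batch_size=1,
+ #                                     num_images_per_prompt=1, device="cuda", dtype=torch.float16)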
932
+ @torch.no_grad()
933
+ def init_latents(self, latents, generator, timestep):
934
+ noise = torch.randn(latents.shape, generator=generator, device=self.vae.device, dtype=self.vae.dtype, layout=torch.strided)
935
+ bsz = latents.shape[0]
936
+ print(f"init latent at {timestep}")
937
+ timestep = torch.tensor([timestep]*bsz, device=self.vae.device)
938
+ # Note that the latents will be scaled aleady by scheduler.add_noise
939
+ latents = self.scheduler.add_noise(latents, noise, timestep)
940
+ return latents
941
+
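+ # Editor's note: a minimal sketch (variable names and the 999 start step are illustrative) of noising
+ # encoded low-quality latents to a chosen starting timestep instead of starting from pure noise:
+ #     lq = pipe.image_processor.preprocess(low_quality_image).to(pipe.vae.device, pipe.vae.dtype)
+ #     latents = pipe.vae.encode(lq).latent_dist.sample() * pipe.vae.config.scaling_factor
+ #     latents = pipe.init_latents(latents, generator, timestep=999)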
942
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
943
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
944
+ shape = (
945
+ batch_size,
946
+ num_channels_latents,
947
+ int(height) // self.vae_scale_factor,
948
+ int(width) // self.vae_scale_factor,
949
+ )
950
+ if isinstance(generator, list) and len(generator) != batch_size:
951
+ raise ValueError(
952
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
953
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
954
+ )
955
+
956
+ if latents is None:
957
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
958
+ else:
959
+ latents = latents.to(device)
960
+
961
+ # scale the initial noise by the standard deviation required by the scheduler
962
+ latents = latents * self.scheduler.init_noise_sigma
963
+ return latents
964
+
965
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
966
+ def _get_add_time_ids(
967
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
968
+ ):
969
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
970
+
971
+ passed_add_embed_dim = (
972
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
973
+ )
974
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
975
+
976
+ if expected_add_embed_dim != passed_add_embed_dim:
977
+ raise ValueError(
978
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
979
+ )
980
+
981
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
982
+ return add_time_ids
983
+
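+ # Editor's note: for a 1024x1024 output with no cropping, the SDXL micro-conditioning reduces to
+ #     add_time_ids = pipe._get_add_time_ids(
+ #         (1024, 1024), (0, 0), (1024, 1024), torch.float16,
+ #         text_encoder_projection_dim=pipe.text_encoder_2.config.projection_dim,
+ #     )  # tensor([[1024., 1024., 0., 0., 1024., 1024.]])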
984
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
985
+ def upcast_vae(self):
986
+ dtype = self.vae.dtype
987
+ self.vae.to(dtype=torch.float32)
988
+ use_torch_2_0_or_xformers = isinstance(
989
+ self.vae.decoder.mid_block.attentions[0].processor,
990
+ (
991
+ AttnProcessor2_0,
992
+ XFormersAttnProcessor,
993
+ LoRAXFormersAttnProcessor,
994
+ LoRAAttnProcessor2_0,
995
+ ),
996
+ )
997
+ # if xformers or torch_2_0 is used attention block does not need
998
+ # to be in float32 which can save lots of memory
999
+ if use_torch_2_0_or_xformers:
1000
+ self.vae.post_quant_conv.to(dtype)
1001
+ self.vae.decoder.conv_in.to(dtype)
1002
+ self.vae.decoder.mid_block.to(dtype)
1003
+
1004
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
1005
+ def get_guidance_scale_embedding(
1006
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
1007
+ ) -> torch.FloatTensor:
1008
+ """
1009
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1010
+
1011
+ Args:
1012
+ w (`torch.Tensor`):
1013
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
1014
+ embedding_dim (`int`, *optional*, defaults to 512):
1015
+ Dimension of the embeddings to generate.
1016
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
1017
+ Data type of the generated embeddings.
1018
+
1019
+ Returns:
1020
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
1021
+ """
1022
+ assert len(w.shape) == 1
1023
+ w = w * 1000.0
1024
+
1025
+ half_dim = embedding_dim // 2
1026
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
1027
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
1028
+ emb = w.to(dtype)[:, None] * emb[None, :]
1029
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
1030
+ if embedding_dim % 2 == 1: # zero pad
1031
+ emb = torch.nn.functional.pad(emb, (0, 1))
1032
+ assert emb.shape == (w.shape[0], embedding_dim)
1033
+ return emb
1034
+
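+ # Editor's note: this is the standard sinusoidal embedding of w * 1000, e.g. (illustrative):
+ #     w = torch.tensor([7.0])                                    # guidance scale per sample
+ #     emb = pipe.get_guidance_scale_embedding(w, embedding_dim=256)
+ #     assert emb.shape == (1, 256)
+ # It is only consulted when `unet.config.time_cond_proj_dim` is not None (guidance-embedded UNets).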
1035
+ @property
1036
+ def guidance_scale(self):
1037
+ return self._guidance_scale
1038
+
1039
+ @property
1040
+ def guidance_rescale(self):
1041
+ return self._guidance_rescale
1042
+
1043
+ @property
1044
+ def clip_skip(self):
1045
+ return self._clip_skip
1046
+
1047
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1048
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1049
+ # corresponds to doing no classifier free guidance.
1050
+ @property
1051
+ def do_classifier_free_guidance(self):
1052
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1053
+
1054
+ @property
1055
+ def cross_attention_kwargs(self):
1056
+ return self._cross_attention_kwargs
1057
+
1058
+ @property
1059
+ def denoising_end(self):
1060
+ return self._denoising_end
1061
+
1062
+ @property
1063
+ def num_timesteps(self):
1064
+ return self._num_timesteps
1065
+
1066
+ @torch.no_grad()
1067
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1068
+ def __call__(
1069
+ self,
1070
+ prompt: Union[str, List[str]] = None,
1071
+ prompt_2: Optional[Union[str, List[str]]] = None,
1072
+ image: PipelineImageInput = None,
1073
+ height: Optional[int] = None,
1074
+ width: Optional[int] = None,
1075
+ num_inference_steps: int = 30,
1076
+ timesteps: List[int] = None,
1077
+ denoising_end: Optional[float] = None,
1078
+ guidance_scale: float = 7.0,
1079
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1080
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1081
+ num_images_per_prompt: Optional[int] = 1,
1082
+ eta: float = 0.0,
1083
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1084
+ latents: Optional[torch.FloatTensor] = None,
1085
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1086
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1087
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1088
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
1089
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1090
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
1091
+ output_type: Optional[str] = "pil",
1092
+ return_dict: bool = True,
1093
+ save_preview_row: bool = False,
1094
+ init_latents_with_lq: bool = True,
1095
+ multistep_restore: bool = False,
1096
+ adastep_restore: bool = False,
1097
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1098
+ guidance_rescale: float = 0.0,
1099
+ controlnet_conditioning_scale: float = 1.0,
1100
+ control_guidance_start: float = 0.0,
1101
+ control_guidance_end: float = 1.0,
1102
+ preview_start: float = 0.0,
1103
+ preview_end: float = 1.0,
1104
+ original_size: Tuple[int, int] = None,
1105
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1106
+ target_size: Tuple[int, int] = None,
1107
+ negative_original_size: Optional[Tuple[int, int]] = None,
1108
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1109
+ negative_target_size: Optional[Tuple[int, int]] = None,
1110
+ clip_skip: Optional[int] = None,
1111
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
1112
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1113
+ previewer_scheduler: KarrasDiffusionSchedulers = None,
1114
+ reference_latents: Optional[torch.FloatTensor] = None,
1115
+ **kwargs,
1116
+ ):
1117
+ r"""
1118
+ The call function to the pipeline for generation.
1119
+
1120
+ Args:
1121
+ prompt (`str` or `List[str]`, *optional*):
1122
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
1123
+ prompt_2 (`str` or `List[str]`, *optional*):
1124
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1125
+ used in both text-encoders.
1126
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
1127
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
1128
+ The low-quality input image to be restored. It serves as the conditioning input to the `aggregator`
1130
+ during the denoising process. If the type is specified as `torch.FloatTensor`, it is passed to the
1131
+ aggregator as is; `PIL.Image.Image` and `np.ndarray` inputs are preprocessed first. The dimensions
1132
+ of the output image default to `image`'s dimensions. If `height` and/or `width` are passed,
1133
+ `image` is resized accordingly. A list of images can be passed for batched restoration, in which
1134
+ case its length must match the prompt batch size.
1134
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1135
+ The height in pixels of the generated image. Anything below 512 pixels won't work well for
1136
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1137
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1138
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
1139
+ The width in pixels of the generated image. Anything below 512 pixels won't work well for
1140
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1141
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1142
+ num_inference_steps (`int`, *optional*, defaults to 30):
1143
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1144
+ expense of slower inference.
1145
+ timesteps (`List[int]`, *optional*):
1146
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1147
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1148
+ passed will be used. Must be in descending order.
1149
+ denoising_end (`float`, *optional*):
1150
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1151
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1152
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
1153
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
1154
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1155
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
1156
+ guidance_scale (`float`, *optional*, defaults to 7.0):
1157
+ A higher guidance scale value encourages the model to generate images closely linked to the text
1158
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
1159
+ negative_prompt (`str` or `List[str]`, *optional*):
1160
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
1161
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
1162
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1163
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
1164
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
1165
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1166
+ The number of images to generate per prompt.
1167
+ eta (`float`, *optional*, defaults to 0.0):
1168
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1169
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1170
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1171
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1172
+ generation deterministic.
1173
+ latents (`torch.FloatTensor`, *optional*):
1174
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
1175
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1176
+ tensor is generated by sampling using the supplied random `generator`.
1177
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1178
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
1179
+ provided, text embeddings are generated from the `prompt` input argument.
1180
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1181
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1182
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
1183
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1184
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
1185
+ not provided, pooled text embeddings are generated from `prompt` input argument.
1186
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1187
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
1188
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
1189
+ argument.
1190
+ ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. If not passed, the low-quality `image` is used.
1191
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
1192
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
1193
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1194
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1195
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1196
+ output_type (`str`, *optional*, defaults to `"pil"`):
1197
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
1198
+ return_dict (`bool`, *optional*, defaults to `True`):
1199
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a
1200
+ plain tuple.
1201
+ cross_attention_kwargs (`dict`, *optional*):
1202
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
1203
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1204
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1205
+ The outputs of the aggregator are multiplied by `controlnet_conditioning_scale` before they are added
1206
+ to the residual in the original `unet`. The argument keeps its ControlNet-style name from the
1207
+ pipeline this implementation is adapted from.
1208
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1209
+ The percentage of total steps at which the aggregator guidance starts applying.
1210
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1211
+ The percentage of total steps at which the aggregator guidance stops applying.
1212
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1213
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1214
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1215
+ explained in section 2.2 of
1216
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1217
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1218
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1219
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1220
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1221
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1222
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1223
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1224
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1225
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1226
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1227
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1228
+ micro-conditioning as explained in section 2.2 of
1229
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1230
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1231
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1232
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
1233
+ micro-conditioning as explained in section 2.2 of
1234
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1235
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1236
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1237
+ To negatively condition the generation process based on a target image resolution. It should be the same
1238
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1239
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1240
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1241
+ clip_skip (`int`, *optional*):
1242
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1243
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1244
+ callback_on_step_end (`Callable`, *optional*):
1245
+ A function that is called at the end of each denoising step during inference. The function is called
1246
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
1247
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
1248
+ `callback_on_step_end_tensor_inputs`.
1249
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1250
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1251
+ will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
1252
+ `._callback_tensor_inputs` attribute of your pipeline class.
1253
+
1254
+ Examples:
1255
+
1256
+ Returns:
1257
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1258
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
1259
+ otherwise a `tuple` is returned containing the output images.
1260
+ """
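+ # The three size tuples documented above are packed into a single micro-conditioning vector,
+ # e.g. original_size + crops_coords_top_left + target_size -> (1024, 1024, 0, 0, 1024, 1024),
+ # which is embedded and added to the time embedding (see `_get_add_time_ids` in step 7.2 below).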
1261
+
1262
+ callback = kwargs.pop("callback", None)
1263
+ callback_steps = kwargs.pop("callback_steps", None)
1264
+
1265
+ if callback is not None:
1266
+ deprecate(
1267
+ "callback",
1268
+ "1.0.0",
1269
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1270
+ )
1271
+ if callback_steps is not None:
1272
+ deprecate(
1273
+ "callback_steps",
1274
+ "1.0.0",
1275
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
1276
+ )
1277
+
1278
+ aggregator = self.aggregator._orig_mod if is_compiled_module(self.aggregator) else self.aggregator
1279
+ if not isinstance(ip_adapter_image, list):
1280
+ ip_adapter_image = [ip_adapter_image] if ip_adapter_image is not None else [image]
1281
+
1282
+ # 1. Check inputs. Raise error if not correct
1283
+ self.check_inputs(
1284
+ prompt,
1285
+ prompt_2,
1286
+ image,
1287
+ callback_steps,
1288
+ negative_prompt,
1289
+ negative_prompt_2,
1290
+ prompt_embeds,
1291
+ negative_prompt_embeds,
1292
+ pooled_prompt_embeds,
1293
+ ip_adapter_image,
1294
+ ip_adapter_image_embeds,
1295
+ negative_pooled_prompt_embeds,
1296
+ controlnet_conditioning_scale,
1297
+ control_guidance_start,
1298
+ control_guidance_end,
1299
+ callback_on_step_end_tensor_inputs,
1300
+ )
1301
+
1302
+ self._guidance_scale = guidance_scale
1303
+ self._guidance_rescale = guidance_rescale
1304
+ self._clip_skip = clip_skip
1305
+ self._cross_attention_kwargs = cross_attention_kwargs
1306
+ self._denoising_end = denoising_end
1307
+
1308
+ # 2. Define call parameters
1309
+ if prompt is not None and isinstance(prompt, str):
1310
+ if not isinstance(image, PIL.Image.Image):
1311
+ batch_size = len(image)
1312
+ else:
1313
+ batch_size = 1
1314
+ prompt = [prompt] * batch_size
1315
+ elif prompt is not None and isinstance(prompt, list):
1316
+ batch_size = len(prompt)
1317
+ assert batch_size == len(image) or (isinstance(image, PIL.Image.Image) or len(image) == 1)
1318
+ else:
1319
+ batch_size = prompt_embeds.shape[0]
1320
+ assert batch_size == len(image) or (isinstance(image, PIL.Image.Image) or len(image) == 1)
1321
+
1322
+ device = self._execution_device
1323
+
1324
+ # 3.1 Encode input prompt
1325
+ text_encoder_lora_scale = (
1326
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1327
+ )
1328
+ (
1329
+ prompt_embeds,
1330
+ negative_prompt_embeds,
1331
+ pooled_prompt_embeds,
1332
+ negative_pooled_prompt_embeds,
1333
+ ) = self.encode_prompt(
1334
+ prompt=prompt,
1335
+ prompt_2=prompt_2,
1336
+ device=device,
1337
+ num_images_per_prompt=num_images_per_prompt,
1338
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1339
+ negative_prompt=negative_prompt,
1340
+ negative_prompt_2=negative_prompt_2,
1341
+ prompt_embeds=prompt_embeds,
1342
+ negative_prompt_embeds=negative_prompt_embeds,
1343
+ pooled_prompt_embeds=pooled_prompt_embeds,
1344
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1345
+ lora_scale=text_encoder_lora_scale,
1346
+ clip_skip=self.clip_skip,
1347
+ )
1348
+
1349
+ # 3.2 Encode ip_adapter_image
1350
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1351
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1352
+ ip_adapter_image,
1353
+ ip_adapter_image_embeds,
1354
+ device,
1355
+ batch_size * num_images_per_prompt,
1356
+ self.do_classifier_free_guidance,
1357
+ )
1358
+
1359
+ # 4. Prepare image
1360
+ image = self.prepare_image(
1361
+ image=image,
1362
+ width=width,
1363
+ height=height,
1364
+ batch_size=batch_size * num_images_per_prompt,
1365
+ num_images_per_prompt=num_images_per_prompt,
1366
+ device=device,
1367
+ dtype=aggregator.dtype,
1368
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1369
+ )
1370
+ height, width = image.shape[-2:]
1371
+ if image.shape[1] != 4:
1372
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1373
+ if needs_upcasting:
1374
+ image = image.float()
1375
+ self.vae.to(dtype=torch.float32)
1376
+ image = self.vae.encode(image).latent_dist.sample()
1377
+ image = image * self.vae.config.scaling_factor
1378
+ if needs_upcasting:
1379
+ self.vae.to(dtype=torch.float16)
1380
+ image = image.to(dtype=torch.float16)
1381
+ else:
1382
+ height = int(height * self.vae_scale_factor)
1383
+ width = int(width * self.vae_scale_factor)
1384
+
1385
+ # 5. Prepare timesteps
1386
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1387
+
1388
+ # 6. Prepare latent variables
1389
+ if init_latents_with_lq:
1390
+ latents = self.init_latents(image, generator, timesteps[0])
1391
+ else:
1392
+ num_channels_latents = self.unet.config.in_channels
1393
+ latents = self.prepare_latents(
1394
+ batch_size * num_images_per_prompt,
1395
+ num_channels_latents,
1396
+ height,
1397
+ width,
1398
+ prompt_embeds.dtype,
1399
+ device,
1400
+ generator,
1401
+ latents,
1402
+ )
1403
+
1404
+ # 6.5 Optionally get Guidance Scale Embedding
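+ # `time_cond_proj_dim` is only set for distilled (LCM-style) UNets: the guidance scale minus one is
+ # turned into a sinusoidal embedding and passed to the UNet as `timestep_cond`.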
1405
+ timestep_cond = None
1406
+ if self.unet.config.time_cond_proj_dim is not None:
1407
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1408
+ timestep_cond = self.get_guidance_scale_embedding(
1409
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1410
+ ).to(device=device, dtype=latents.dtype)
1411
+
1412
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1413
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1414
+
1415
+ # 7.1 Create tensor stating which controlnets to keep
1416
+ controlnet_keep = []
1417
+ previewing = []
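+ # Build per-step 0/1 gates: `controlnet_keep[i]` enables the aggregator residuals only within
+ # [control_guidance_start, control_guidance_end], and `previewing[i]` enables the LCM previewer
+ # only within [preview_start, preview_end], both expressed as fractions of the sampling schedule.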
1418
+ for i in range(len(timesteps)):
1419
+ keeps = 1.0 - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
1420
+ controlnet_keep.append(keeps)
1421
+ use_preview = 1.0 - float(i / len(timesteps) < preview_start or (i + 1) / len(timesteps) > preview_end)
1422
+ previewing.append(use_preview)
1423
+ if isinstance(controlnet_conditioning_scale, list):
1424
+ assert len(controlnet_conditioning_scale) == len(timesteps), f"{len(controlnet_conditioning_scale)} controlnet scales do not match number of sampling steps {len(timesteps)}"
1425
+ else:
1426
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet_keep)
1427
+
1428
+ # 7.2 Prepare added time ids & embeddings
1429
+ original_size = original_size or (height, width)
1430
+ target_size = target_size or (height, width)
1431
+
1432
+ add_text_embeds = pooled_prompt_embeds
1433
+ if self.text_encoder_2 is None:
1434
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1435
+ else:
1436
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1437
+
1438
+ add_time_ids = self._get_add_time_ids(
1439
+ original_size,
1440
+ crops_coords_top_left,
1441
+ target_size,
1442
+ dtype=prompt_embeds.dtype,
1443
+ text_encoder_projection_dim=text_encoder_projection_dim,
1444
+ )
1445
+
1446
+ if negative_original_size is not None and negative_target_size is not None:
1447
+ negative_add_time_ids = self._get_add_time_ids(
1448
+ negative_original_size,
1449
+ negative_crops_coords_top_left,
1450
+ negative_target_size,
1451
+ dtype=prompt_embeds.dtype,
1452
+ text_encoder_projection_dim=text_encoder_projection_dim,
1453
+ )
1454
+ else:
1455
+ negative_add_time_ids = add_time_ids
1456
+
1457
+ if self.do_classifier_free_guidance:
1458
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1459
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1460
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1461
+ image = torch.cat([image] * 2, dim=0)
1462
+
1463
+ prompt_embeds = prompt_embeds.to(device)
1464
+ add_text_embeds = add_text_embeds.to(device)
1465
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1466
+
1467
+ # 8. Denoising loop
1468
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1469
+
1470
+ # 8.1 Apply denoising_end
1471
+ if (
1472
+ self.denoising_end is not None
1473
+ and isinstance(self.denoising_end, float)
1474
+ and self.denoising_end > 0
1475
+ and self.denoising_end < 1
1476
+ ):
1477
+ discrete_timestep_cutoff = int(
1478
+ round(
1479
+ self.scheduler.config.num_train_timesteps
1480
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1481
+ )
1482
+ )
1483
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1484
+ timesteps = timesteps[:num_inference_steps]
1485
+
1486
+ is_unet_compiled = is_compiled_module(self.unet)
1487
+ is_aggregator_compiled = is_compiled_module(self.aggregator)
1488
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1489
+ previewer_mean = torch.zeros_like(latents)
1490
+ unet_mean = torch.zeros_like(latents)
1491
+ preview_factor = torch.ones(
1492
+ (latents.shape[0], *((1,) * (len(latents.shape) - 1))), dtype=latents.dtype, device=latents.device
1493
+ )
1494
+
1495
+ self._num_timesteps = len(timesteps)
1496
+ preview_row = []
1497
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1498
+ for i, t in enumerate(timesteps):
1499
+ # Relevant thread:
1500
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
1501
+ if (is_unet_compiled and is_aggregator_compiled) and is_torch_higher_equal_2_1:
1502
+ torch._inductor.cudagraph_mark_step_begin()
1503
+ # expand the latents if we are doing classifier free guidance
1504
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1505
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1506
+ prev_t = t
1507
+ unet_model_input = latent_model_input
1508
+
1509
+ added_cond_kwargs = {
1510
+ "text_embeds": add_text_embeds,
1511
+ "time_ids": add_time_ids,
1512
+ "image_embeds": image_embeds
1513
+ }
1514
+ aggregator_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1515
+
1516
+ # prepare time_embeds in advance as adapter input
1517
+ cross_attention_t_emb = self.unet.get_time_embed(sample=latent_model_input, timestep=t)
1518
+ cross_attention_emb = self.unet.time_embedding(cross_attention_t_emb, timestep_cond)
1519
+ cross_attention_aug_emb = None
1520
+
1521
+ cross_attention_aug_emb = self.unet.get_aug_embed(
1522
+ emb=cross_attention_emb,
1523
+ encoder_hidden_states=prompt_embeds,
1524
+ added_cond_kwargs=added_cond_kwargs
1525
+ )
1526
+
1527
+ cross_attention_emb = cross_attention_emb + cross_attention_aug_emb if cross_attention_aug_emb is not None else cross_attention_emb
1528
+
1529
+ if self.unet.time_embed_act is not None:
1530
+ cross_attention_emb = self.unet.time_embed_act(cross_attention_emb)
1531
+
1532
+ current_cross_attention_kwargs = {"temb": cross_attention_emb}
1533
+ if cross_attention_kwargs is not None:
1534
+ for k,v in cross_attention_kwargs.items():
1535
+ current_cross_attention_kwargs[k] = v
1536
+ self._cross_attention_kwargs = current_cross_attention_kwargs
1537
+
1538
+ # adaptive restoration factors
1539
+ adaRes_scale = preview_factor.to(latent_model_input.dtype).clamp(0.0, controlnet_conditioning_scale[i])
1540
+ cond_scale = adaRes_scale * controlnet_keep[i]
1541
+ cond_scale = torch.cat([cond_scale] * 2) if self.do_classifier_free_guidance else cond_scale
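+ # `cond_scale` combines the per-step keep gate with the adaptive `preview_factor` (clamped to
+ # [0, controlnet_conditioning_scale[i]]); it later scales the aggregator's down/mid residuals per sample.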
1542
+
1543
+ if (cond_scale>0.1).sum().item() > 0:
1544
+ if previewing[i] > 0:
1545
+ # preview with LCM
1546
+ self.unet.enable_adapters()
1547
+ preview_noise = self.unet(
1548
+ latent_model_input,
1549
+ t,
1550
+ encoder_hidden_states=prompt_embeds,
1551
+ timestep_cond=timestep_cond,
1552
+ cross_attention_kwargs=self.cross_attention_kwargs,
1553
+ added_cond_kwargs=added_cond_kwargs,
1554
+ return_dict=False,
1555
+ )[0]
1556
+ preview_latent = previewer_scheduler.step(
1557
+ preview_noise,
1558
+ t.to(dtype=torch.int64),
1559
+ # torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
1560
+ latent_model_input, # scaled latents here for compatibility
1561
+ return_dict=False
1562
+ )[0]
1563
+ self.unet.disable_adapters()
1564
+
1565
+ if self.do_classifier_free_guidance:
1566
+ preview_row.append(preview_latent.chunk(2)[1].to('cpu'))
1567
+ else:
1568
+ preview_row.append(preview_latent.to('cpu'))
1569
+ # Prepare 2nd order step.
1570
+ if multistep_restore and i+1 < len(timesteps):
1571
+ noise_preview = preview_noise.chunk(2)[1] if self.do_classifier_free_guidance else preview_noise
1572
+ first_step = self.scheduler.step(
1573
+ noise_preview, t, latents,
1574
+ **extra_step_kwargs, return_dict=True, step_forward=False
1575
+ )
1576
+ prev_t = timesteps[i + 1]
1577
+ unet_model_input = torch.cat([first_step.prev_sample] * 2) if self.do_classifier_free_guidance else first_step.prev_sample
1578
+ unet_model_input = self.scheduler.scale_model_input(unet_model_input, prev_t, heun_step=True)
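+ # With `multistep_restore`, the previewer's noise estimate drives a provisional (Heun-style)
+ # scheduler step, so the aggregator and UNet below are evaluated at the next timestep `prev_t` instead of `t`.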
1579
+
1580
+ elif reference_latents is not None:
1581
+ preview_latent = torch.cat([reference_latents] * 2) if self.do_classifier_free_guidance else reference_latents
1582
+ else:
1583
+ preview_latent = image
1584
+
1585
+ # Add fresh noise
1586
+ # preview_noise = torch.randn_like(preview_latent)
1587
+ # preview_latent = self.scheduler.add_noise(preview_latent, preview_noise, t)
1588
+
1589
+ preview_latent = preview_latent.to(dtype=next(aggregator.parameters()).dtype)
1590
+
1591
+ # Aggregator inference
1592
+ down_block_res_samples, mid_block_res_sample = aggregator(
1593
+ image,
1594
+ prev_t,
1595
+ encoder_hidden_states=prompt_embeds,
1596
+ controlnet_cond=preview_latent,
1597
+ # conditioning_scale=cond_scale,
1598
+ added_cond_kwargs=aggregator_added_cond_kwargs,
1599
+ return_dict=False,
1600
+ )
1601
+
1602
+ # aggregator features scaling
1603
+ down_block_res_samples = [sample*cond_scale for sample in down_block_res_samples]
1604
+ mid_block_res_sample = mid_block_res_sample*cond_scale
1605
+
1606
+ # predict the noise residual
1607
+ noise_pred = self.unet(
1608
+ unet_model_input,
1609
+ prev_t,
1610
+ encoder_hidden_states=prompt_embeds,
1611
+ timestep_cond=timestep_cond,
1612
+ cross_attention_kwargs=self.cross_attention_kwargs,
1613
+ down_block_additional_residuals=down_block_res_samples,
1614
+ mid_block_additional_residual=mid_block_res_sample,
1615
+ added_cond_kwargs=added_cond_kwargs,
1616
+ return_dict=False,
1617
+ )[0]
1618
+
1619
+ # perform guidance
1620
+ if self.do_classifier_free_guidance:
1621
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1622
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1623
+
1624
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1625
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1626
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1627
+
1628
+ # compute the previous noisy sample x_t -> x_t-1
1629
+ latents_dtype = latents.dtype
1630
+ unet_step = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
1631
+ latents = unet_step.prev_sample
1632
+
1633
+ # Update adaRes factors
1634
+ unet_pred_latent = unet_step.pred_original_sample
1635
+
1636
+ # Adaptive restoration.
1637
+ if adastep_restore:
1638
+ pred_x0_l2 = ((preview_latent[latents.shape[0]:].float()-unet_pred_latent.float())).pow(2).sum(dim=(1,2,3))
1639
+ previewer_l2 = ((preview_latent[latents.shape[0]:].float()-previewer_mean.float())).pow(2).sum(dim=(1,2,3))
1640
+ # unet_l2 = ((unet_pred_latent.float()-unet_mean.float())).pow(2).sum(dim=(1,2,3)).sqrt()
1641
+ # l2_error = (((preview_latent[latents.shape[0]:]-previewer_mean) - (unet_pred_latent-unet_mean))).pow(2).mean(dim=(1,2,3))
1642
+ # preview_error = torch.nn.functional.cosine_similarity(preview_latent[latents.shape[0]:].reshape(latents.shape[0], -1), unet_pred_latent.reshape(latents.shape[0],-1))
1643
+ previewer_mean = preview_latent[latents.shape[0]:]
1644
+ unet_mean = unet_pred_latent
1645
+ preview_factor = (pred_x0_l2 / previewer_l2).reshape(-1, 1, 1, 1)
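+ # Ratio of the previewer-vs-UNet x0 disagreement to the previewer's own change since the last step;
+ # it becomes the next step's adaptive restoration factor, shrinking guidance once the two predictions agree.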
1646
+
1647
+ if latents.dtype != latents_dtype:
1648
+ if torch.backends.mps.is_available():
1649
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1650
+ latents = latents.to(latents_dtype)
1651
+
1652
+ if callback_on_step_end is not None:
1653
+ callback_kwargs = {}
1654
+ for k in callback_on_step_end_tensor_inputs:
1655
+ callback_kwargs[k] = locals()[k]
1656
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1657
+
1658
+ latents = callback_outputs.pop("latents", latents)
1659
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1660
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1661
+
1662
+ # call the callback, if provided
1663
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1664
+ progress_bar.update()
1665
+ if callback is not None and i % callback_steps == 0:
1666
+ step_idx = i // getattr(self.scheduler, "order", 1)
1667
+ callback(step_idx, t, latents)
1668
+
1669
+ if not output_type == "latent":
1670
+ # make sure the VAE is in float32 mode, as it overflows in float16
1671
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1672
+
1673
+ if needs_upcasting:
1674
+ self.upcast_vae()
1675
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1676
+
1677
+ # unscale/denormalize the latents
1678
+ # denormalize with the mean and std if available and not None
1679
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1680
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1681
+ if has_latents_mean and has_latents_std:
1682
+ latents_mean = (
1683
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1684
+ )
1685
+ latents_std = (
1686
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1687
+ )
1688
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1689
+ else:
1690
+ latents = latents / self.vae.config.scaling_factor
1691
+
1692
+ image = self.vae.decode(latents, return_dict=False)[0]
1693
+
1694
+ # cast back to fp16 if needed
1695
+ if needs_upcasting:
1696
+ self.vae.to(dtype=torch.float16)
1697
+ else:
1698
+ image = latents
1699
+
1700
+ if not output_type == "latent":
1701
+ # apply watermark if available
1702
+ if self.watermark is not None:
1703
+ image = self.watermark.apply_watermark(image)
1704
+
1705
+ image = self.image_processor.postprocess(image, output_type=output_type)
1706
+
1707
+ if save_preview_row:
1708
+ preview_image_row = []
1709
+ if needs_upcasting:
1710
+ self.upcast_vae()
1711
+ for preview_latents in preview_row:
1712
+ preview_latents = preview_latents.to(device=self.device, dtype=next(iter(self.vae.post_quant_conv.parameters())).dtype)
1713
+ if has_latents_mean and has_latents_std:
1714
+ latents_mean = (
1715
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(preview_latents.device, preview_latents.dtype)
1716
+ )
1717
+ latents_std = (
1718
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(preview_latents.device, preview_latents.dtype)
1719
+ )
1720
+ preview_latents = preview_latents * latents_std / self.vae.config.scaling_factor + latents_mean
1721
+ else:
1722
+ preview_latents = preview_latents / self.vae.config.scaling_factor
1723
+
1724
+ preview_image = self.vae.decode(preview_latents, return_dict=False)[0]
1725
+ preview_image = self.image_processor.postprocess(preview_image, output_type=output_type)
1726
+ preview_image_row.append(preview_image)
1727
+
1728
+ # cast back to fp16 if needed
1729
+ if needs_upcasting:
1730
+ self.vae.to(dtype=torch.float16)
1731
+
1732
+ # Offload all models
1733
+ self.maybe_free_model_hooks()
1734
+
1735
+ if not return_dict:
1736
+ if save_preview_row:
1737
+ return (image, preview_image_row)
1738
+ return (image,)
1739
+
1740
+ return StableDiffusionXLPipelineOutput(images=image)
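A minimal call sketch for the pipeline above, grounded only in the arguments visible in this `__call__`: the `pipe` instance, file names, prompt, and step count are placeholders, and loading the aggregator / previewer LoRA weights is not shown.

```py
from PIL import Image

# Sketch only: assumes `pipe` is an already-constructed instance of the restoration pipeline
# defined in pipelines/sdxl_instantir.py, with its aggregator and previewer (LoRA) weights loaded.
lq = Image.open("low_quality.png").convert("RGB")  # placeholder input path

out = pipe(
    prompt="a sharp, high-quality photograph",  # placeholder prompt
    image=lq,                      # low-quality input; reused as ip_adapter_image when none is passed
    num_inference_steps=30,
    save_preview_row=True,         # also return the decoded previewer estimates for each step
    return_dict=False,             # -> (images, preview_rows) instead of StableDiffusionXLPipelineOutput
)
restored_images, preview_rows = out
restored_images[0].save("restored.png")  # PIL output, assuming the default output_type="pil"
```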
pipelines/stage1_sdxl_pipeline.py ADDED
@@ -0,0 +1,1283 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from transformers import (
20
+ CLIPImageProcessor,
21
+ CLIPTextModel,
22
+ CLIPTextModelWithProjection,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ )
26
+
27
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
28
+ from ...loaders import (
29
+ FromSingleFileMixin,
30
+ IPAdapterMixin,
31
+ StableDiffusionXLLoraLoaderMixin,
32
+ TextualInversionLoaderMixin,
33
+ )
34
+ from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
35
+ from ...models.attention_processor import (
36
+ AttnProcessor2_0,
37
+ FusedAttnProcessor2_0,
38
+ LoRAAttnProcessor2_0,
39
+ LoRAXFormersAttnProcessor,
40
+ XFormersAttnProcessor,
41
+ )
42
+ from ...models.lora import adjust_lora_scale_text_encoder
43
+ from ...schedulers import KarrasDiffusionSchedulers
44
+ from ...utils import (
45
+ USE_PEFT_BACKEND,
46
+ deprecate,
47
+ is_invisible_watermark_available,
48
+ is_torch_xla_available,
49
+ logging,
50
+ replace_example_docstring,
51
+ scale_lora_layers,
52
+ unscale_lora_layers,
53
+ )
54
+ from ...utils.torch_utils import randn_tensor
55
+ from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
56
+ from .pipeline_output import StableDiffusionXLPipelineOutput
57
+
58
+
59
+ if is_invisible_watermark_available():
60
+ from .watermark import StableDiffusionXLWatermarker
61
+
62
+ if is_torch_xla_available():
63
+ import torch_xla.core.xla_model as xm
64
+
65
+ XLA_AVAILABLE = True
66
+ else:
67
+ XLA_AVAILABLE = False
68
+
69
+
70
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
71
+
72
+ EXAMPLE_DOC_STRING = """
73
+ Examples:
74
+ ```py
75
+ >>> import torch
76
+ >>> from diffusers import StableDiffusionXLPipeline
77
+
78
+ >>> pipe = StableDiffusionXLPipeline.from_pretrained(
79
+ ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
80
+ ... )
81
+ >>> pipe = pipe.to("cuda")
82
+
83
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
84
+ >>> image = pipe(prompt).images[0]
85
+ ```
86
+ """
87
+
88
+
89
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
90
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
91
+ """
92
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
93
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
94
+ """
95
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
96
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
97
+ # rescale the results from guidance (fixes overexposure)
98
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
99
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
100
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
101
+ return noise_cfg
102
+
103
+
104
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
105
+ def retrieve_timesteps(
106
+ scheduler,
107
+ num_inference_steps: Optional[int] = None,
108
+ device: Optional[Union[str, torch.device]] = None,
109
+ timesteps: Optional[List[int]] = None,
110
+ **kwargs,
111
+ ):
112
+ """
113
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
114
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
115
+
116
+ Args:
117
+ scheduler (`SchedulerMixin`):
118
+ The scheduler to get timesteps from.
119
+ num_inference_steps (`int`):
120
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
121
+ must be `None`.
122
+ device (`str` or `torch.device`, *optional*):
123
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
124
+ timesteps (`List[int]`, *optional*):
125
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
126
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
127
+ must be `None`.
128
+
129
+ Returns:
130
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
131
+ second element is the number of inference steps.
132
+ """
133
+ if timesteps is not None:
134
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
135
+ if not accepts_timesteps:
136
+ raise ValueError(
137
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
138
+ f" timestep schedules. Please check whether you are using the correct scheduler."
139
+ )
140
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
141
+ timesteps = scheduler.timesteps
142
+ num_inference_steps = len(timesteps)
143
+ else:
144
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
145
+ timesteps = scheduler.timesteps
146
+ return timesteps, num_inference_steps
147
+
148
+
149
+ class StableDiffusionXLPipeline(
150
+ DiffusionPipeline,
151
+ StableDiffusionMixin,
152
+ FromSingleFileMixin,
153
+ StableDiffusionXLLoraLoaderMixin,
154
+ TextualInversionLoaderMixin,
155
+ IPAdapterMixin,
156
+ ):
157
+ r"""
158
+ Pipeline for text-to-image generation using Stable Diffusion XL.
159
+
160
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
161
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
162
+
163
+ The pipeline also inherits the following loading methods:
164
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
165
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
166
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
167
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
168
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
169
+
170
+ Args:
171
+ vae ([`AutoencoderKL`]):
172
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
173
+ text_encoder ([`CLIPTextModel`]):
174
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
175
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
176
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
177
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
178
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
179
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
180
+ specifically the
181
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
182
+ variant.
183
+ tokenizer (`CLIPTokenizer`):
184
+ Tokenizer of class
185
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
186
+ tokenizer_2 (`CLIPTokenizer`):
187
+ Second Tokenizer of class
188
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
189
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
190
+ scheduler ([`SchedulerMixin`]):
191
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
192
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
193
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
194
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
195
+ `stabilityai/stable-diffusion-xl-base-1-0`.
196
+ add_watermarker (`bool`, *optional*):
197
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
198
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
199
+ watermarker will be used.
200
+ """
201
+
202
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
203
+ _optional_components = [
204
+ "tokenizer",
205
+ "tokenizer_2",
206
+ "text_encoder",
207
+ "text_encoder_2",
208
+ "image_encoder",
209
+ "feature_extractor",
210
+ ]
211
+ _callback_tensor_inputs = [
212
+ "latents",
213
+ "prompt_embeds",
214
+ "negative_prompt_embeds",
215
+ "add_text_embeds",
216
+ "add_time_ids",
217
+ "negative_pooled_prompt_embeds",
218
+ "negative_add_time_ids",
219
+ ]
220
+
221
+ def __init__(
222
+ self,
223
+ vae: AutoencoderKL,
224
+ text_encoder: CLIPTextModel,
225
+ text_encoder_2: CLIPTextModelWithProjection,
226
+ tokenizer: CLIPTokenizer,
227
+ tokenizer_2: CLIPTokenizer,
228
+ unet: UNet2DConditionModel,
229
+ scheduler: KarrasDiffusionSchedulers,
230
+ image_encoder: CLIPVisionModelWithProjection = None,
231
+ feature_extractor: CLIPImageProcessor = None,
232
+ force_zeros_for_empty_prompt: bool = True,
233
+ add_watermarker: Optional[bool] = None,
234
+ ):
235
+ super().__init__()
236
+
237
+ self.register_modules(
238
+ vae=vae,
239
+ text_encoder=text_encoder,
240
+ text_encoder_2=text_encoder_2,
241
+ tokenizer=tokenizer,
242
+ tokenizer_2=tokenizer_2,
243
+ unet=unet,
244
+ scheduler=scheduler,
245
+ image_encoder=image_encoder,
246
+ feature_extractor=feature_extractor,
247
+ )
248
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
249
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
250
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
251
+
252
+ self.default_sample_size = self.unet.config.sample_size
253
+
254
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
255
+
256
+ if add_watermarker:
257
+ self.watermark = StableDiffusionXLWatermarker()
258
+ else:
259
+ self.watermark = None
260
+
261
+ def encode_prompt(
262
+ self,
263
+ prompt: str,
264
+ prompt_2: Optional[str] = None,
265
+ device: Optional[torch.device] = None,
266
+ num_images_per_prompt: int = 1,
267
+ do_classifier_free_guidance: bool = True,
268
+ negative_prompt: Optional[str] = None,
269
+ negative_prompt_2: Optional[str] = None,
270
+ prompt_embeds: Optional[torch.FloatTensor] = None,
271
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
272
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
273
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
274
+ lora_scale: Optional[float] = None,
275
+ clip_skip: Optional[int] = None,
276
+ ):
277
+ r"""
278
+ Encodes the prompt into text encoder hidden states.
279
+
280
+ Args:
281
+ prompt (`str` or `List[str]`, *optional*):
282
+ prompt to be encoded
283
+ prompt_2 (`str` or `List[str]`, *optional*):
284
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
285
+ used in both text-encoders
286
+ device: (`torch.device`):
287
+ torch device
288
+ num_images_per_prompt (`int`):
289
+ number of images that should be generated per prompt
290
+ do_classifier_free_guidance (`bool`):
291
+ whether to use classifier free guidance or not
292
+ negative_prompt (`str` or `List[str]`, *optional*):
293
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
294
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
295
+ less than `1`).
296
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
297
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
298
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
299
+ prompt_embeds (`torch.FloatTensor`, *optional*):
300
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
301
+ provided, text embeddings will be generated from `prompt` input argument.
302
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
303
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
304
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
305
+ argument.
306
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
307
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
308
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
309
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
310
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
311
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
312
+ input argument.
313
+ lora_scale (`float`, *optional*):
314
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
315
+ clip_skip (`int`, *optional*):
316
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
317
+ the output of the pre-final layer will be used for computing the prompt embeddings.
318
+ """
319
+ device = device or self._execution_device
320
+
321
+ # set lora scale so that monkey patched LoRA
322
+ # function of text encoder can correctly access it
323
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
324
+ self._lora_scale = lora_scale
325
+
326
+ # dynamically adjust the LoRA scale
327
+ if self.text_encoder is not None:
328
+ if not USE_PEFT_BACKEND:
329
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
330
+ else:
331
+ scale_lora_layers(self.text_encoder, lora_scale)
332
+
333
+ if self.text_encoder_2 is not None:
334
+ if not USE_PEFT_BACKEND:
335
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
336
+ else:
337
+ scale_lora_layers(self.text_encoder_2, lora_scale)
338
+
339
+ prompt = [prompt] if isinstance(prompt, str) else prompt
340
+
341
+ if prompt is not None:
342
+ batch_size = len(prompt)
343
+ else:
344
+ batch_size = prompt_embeds.shape[0]
345
+
346
+ # Define tokenizers and text encoders
347
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
348
+ text_encoders = (
349
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
350
+ )
351
+
352
+ if prompt_embeds is None:
353
+ prompt_2 = prompt_2 or prompt
354
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
355
+
356
+ # textual inversion: process multi-vector tokens if necessary
357
+ prompt_embeds_list = []
358
+ prompts = [prompt, prompt_2]
359
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
360
+ if isinstance(self, TextualInversionLoaderMixin):
361
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
362
+
363
+ text_inputs = tokenizer(
364
+ prompt,
365
+ padding="max_length",
366
+ max_length=tokenizer.model_max_length,
367
+ truncation=True,
368
+ return_tensors="pt",
369
+ )
370
+
371
+ text_input_ids = text_inputs.input_ids
372
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
373
+
374
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
375
+ text_input_ids, untruncated_ids
376
+ ):
377
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
378
+ logger.warning(
379
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
380
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
381
+ )
382
+
383
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
384
+
385
+ # We are always interested only in the pooled output of the final text encoder
386
+ pooled_prompt_embeds = prompt_embeds[0]
387
+ if clip_skip is None:
388
+ prompt_embeds = prompt_embeds.hidden_states[-2]
389
+ else:
390
+ # "2" because SDXL always indexes from the penultimate layer.
391
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
392
+
393
+ prompt_embeds_list.append(prompt_embeds)
394
+
395
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
396
+
397
+ # get unconditional embeddings for classifier free guidance
398
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
399
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
400
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
401
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
402
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
403
+ negative_prompt = negative_prompt or ""
404
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
405
+
406
+ # normalize str to list
407
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
408
+ negative_prompt_2 = (
409
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
410
+ )
411
+
412
+ uncond_tokens: List[str]
413
+ if prompt is not None and type(prompt) is not type(negative_prompt):
414
+ raise TypeError(
415
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
416
+ f" {type(prompt)}."
417
+ )
418
+ elif batch_size != len(negative_prompt):
419
+ raise ValueError(
420
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
421
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
422
+ " the batch size of `prompt`."
423
+ )
424
+ else:
425
+ uncond_tokens = [negative_prompt, negative_prompt_2]
426
+
427
+ negative_prompt_embeds_list = []
428
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
429
+ if isinstance(self, TextualInversionLoaderMixin):
430
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
431
+
432
+ max_length = prompt_embeds.shape[1]
433
+ uncond_input = tokenizer(
434
+ negative_prompt,
435
+ padding="max_length",
436
+ max_length=max_length,
437
+ truncation=True,
438
+ return_tensors="pt",
439
+ )
440
+
441
+ negative_prompt_embeds = text_encoder(
442
+ uncond_input.input_ids.to(device),
443
+ output_hidden_states=True,
444
+ )
445
+ # We are always interested only in the pooled output of the final text encoder
446
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
447
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
448
+
449
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
450
+
451
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
452
+
453
+ if self.text_encoder_2 is not None:
454
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
455
+ else:
456
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
457
+
458
+ bs_embed, seq_len, _ = prompt_embeds.shape
459
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
460
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
461
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
462
+
463
+ if do_classifier_free_guidance:
464
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
465
+ seq_len = negative_prompt_embeds.shape[1]
466
+
467
+ if self.text_encoder_2 is not None:
468
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
469
+ else:
470
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
471
+
472
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
473
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
474
+
475
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
476
+ bs_embed * num_images_per_prompt, -1
477
+ )
478
+ if do_classifier_free_guidance:
479
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
480
+ bs_embed * num_images_per_prompt, -1
481
+ )
482
+
483
+ if self.text_encoder is not None:
484
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
485
+ # Retrieve the original scale by scaling back the LoRA layers
486
+ unscale_lora_layers(self.text_encoder, lora_scale)
487
+
488
+ if self.text_encoder_2 is not None:
489
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
490
+ # Retrieve the original scale by scaling back the LoRA layers
491
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
492
+
493
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
494
+
495
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
496
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
497
+ dtype = next(self.image_encoder.parameters()).dtype
498
+
499
+ if not isinstance(image, torch.Tensor):
500
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
501
+
502
+ image = image.to(device=device, dtype=dtype)
503
+ if output_hidden_states:
504
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
505
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
506
+ uncond_image_enc_hidden_states = self.image_encoder(
507
+ torch.zeros_like(image), output_hidden_states=True
508
+ ).hidden_states[-2]
509
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
510
+ num_images_per_prompt, dim=0
511
+ )
512
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
513
+ else:
514
+ image_embeds = self.image_encoder(image).image_embeds
515
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
516
+ uncond_image_embeds = torch.zeros_like(image_embeds)
517
+
518
+ return image_embeds, uncond_image_embeds
519
+
520
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
521
+ def prepare_ip_adapter_image_embeds(
522
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
523
+ ):
524
+ if ip_adapter_image_embeds is None:
525
+ if not isinstance(ip_adapter_image, list):
526
+ ip_adapter_image = [ip_adapter_image]
527
+
528
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
529
+ raise ValueError(
530
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
531
+ )
532
+
533
+ image_embeds = []
534
+ for single_ip_adapter_image, image_proj_layer in zip(
535
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
536
+ ):
537
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
538
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
539
+ single_ip_adapter_image, device, 1, output_hidden_state
540
+ )
541
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
542
+ single_negative_image_embeds = torch.stack(
543
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
544
+ )
545
+
546
+ if do_classifier_free_guidance:
547
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
548
+ single_image_embeds = single_image_embeds.to(device)
549
+
550
+ image_embeds.append(single_image_embeds)
551
+ else:
552
+ repeat_dims = [1]
553
+ image_embeds = []
554
+ for single_image_embeds in ip_adapter_image_embeds:
555
+ if do_classifier_free_guidance:
556
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
557
+ single_image_embeds = single_image_embeds.repeat(
558
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
559
+ )
560
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
561
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
562
+ )
563
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
564
+ else:
565
+ single_image_embeds = single_image_embeds.repeat(
566
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
567
+ )
568
+ image_embeds.append(single_image_embeds)
569
+
570
+ return image_embeds
571
+
572
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
573
+ def prepare_extra_step_kwargs(self, generator, eta):
574
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
575
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
576
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
577
+ # and should be between [0, 1]
578
+
579
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
580
+ extra_step_kwargs = {}
581
+ if accepts_eta:
582
+ extra_step_kwargs["eta"] = eta
583
+
584
+ # check if the scheduler accepts generator
585
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
586
+ if accepts_generator:
587
+ extra_step_kwargs["generator"] = generator
588
+ return extra_step_kwargs
589
+
590
+ def check_inputs(
591
+ self,
592
+ prompt,
593
+ prompt_2,
594
+ height,
595
+ width,
596
+ callback_steps,
597
+ negative_prompt=None,
598
+ negative_prompt_2=None,
599
+ prompt_embeds=None,
600
+ negative_prompt_embeds=None,
601
+ pooled_prompt_embeds=None,
602
+ negative_pooled_prompt_embeds=None,
603
+ ip_adapter_image=None,
604
+ ip_adapter_image_embeds=None,
605
+ callback_on_step_end_tensor_inputs=None,
606
+ ):
607
+ if height % 8 != 0 or width % 8 != 0:
608
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
609
+
610
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
611
+ raise ValueError(
612
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
613
+ f" {type(callback_steps)}."
614
+ )
615
+
616
+ if callback_on_step_end_tensor_inputs is not None and not all(
617
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
618
+ ):
619
+ raise ValueError(
620
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
621
+ )
622
+
623
+ if prompt is not None and prompt_embeds is not None:
624
+ raise ValueError(
625
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
626
+ " only forward one of the two."
627
+ )
628
+ elif prompt_2 is not None and prompt_embeds is not None:
629
+ raise ValueError(
630
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
631
+ " only forward one of the two."
632
+ )
633
+ elif prompt is None and prompt_embeds is None:
634
+ raise ValueError(
635
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
636
+ )
637
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
638
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
639
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
640
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
641
+
642
+ if negative_prompt is not None and negative_prompt_embeds is not None:
643
+ raise ValueError(
644
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
645
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
646
+ )
647
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
648
+ raise ValueError(
649
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
650
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
651
+ )
652
+
653
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
654
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
655
+ raise ValueError(
656
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
657
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
658
+ f" {negative_prompt_embeds.shape}."
659
+ )
660
+
661
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
662
+ raise ValueError(
663
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
664
+ )
665
+
666
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
667
+ raise ValueError(
668
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
669
+ )
670
+
671
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
672
+ raise ValueError(
673
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
674
+ )
675
+
676
+ if ip_adapter_image_embeds is not None:
677
+ if not isinstance(ip_adapter_image_embeds, list):
678
+ raise ValueError(
679
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
680
+ )
681
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
682
+ raise ValueError(
683
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
684
+ )
685
+
686
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
687
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
688
+ shape = (
689
+ batch_size,
690
+ num_channels_latents,
691
+ int(height) // self.vae_scale_factor,
692
+ int(width) // self.vae_scale_factor,
693
+ )
694
+ if isinstance(generator, list) and len(generator) != batch_size:
695
+ raise ValueError(
696
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
697
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
698
+ )
699
+
700
+ if latents is None:
701
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
702
+ else:
703
+ latents = latents.to(device)
704
+
705
+ # scale the initial noise by the standard deviation required by the scheduler
706
+ latents = latents * self.scheduler.init_noise_sigma
707
+ return latents
708
+
709
+ def _get_add_time_ids(
710
+ self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
711
+ ):
712
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
713
+
714
+ passed_add_embed_dim = (
715
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
716
+ )
717
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
718
+
719
+ if expected_add_embed_dim != passed_add_embed_dim:
720
+ raise ValueError(
721
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
722
+ )
723
+
724
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
725
+ return add_time_ids
726
+
727
+ def upcast_vae(self):
728
+ dtype = self.vae.dtype
729
+ self.vae.to(dtype=torch.float32)
730
+ use_torch_2_0_or_xformers = isinstance(
731
+ self.vae.decoder.mid_block.attentions[0].processor,
732
+ (
733
+ AttnProcessor2_0,
734
+ XFormersAttnProcessor,
735
+ LoRAXFormersAttnProcessor,
736
+ LoRAAttnProcessor2_0,
737
+ FusedAttnProcessor2_0,
738
+ ),
739
+ )
740
+ # if xformers or torch_2_0 is used attention block does not need
741
+ # to be in float32 which can save lots of memory
742
+ if use_torch_2_0_or_xformers:
743
+ self.vae.post_quant_conv.to(dtype)
744
+ self.vae.decoder.conv_in.to(dtype)
745
+ self.vae.decoder.mid_block.to(dtype)
746
+
747
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
748
+ def get_guidance_scale_embedding(
749
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
750
+ ) -> torch.FloatTensor:
751
+ """
752
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
753
+
754
+ Args:
755
+ w (`torch.Tensor`):
756
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
757
+ embedding_dim (`int`, *optional*, defaults to 512):
758
+ Dimension of the embeddings to generate.
759
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
760
+ Data type of the generated embeddings.
761
+
762
+ Returns:
763
+ `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
764
+ """
765
+ assert len(w.shape) == 1
766
+ w = w * 1000.0
767
+
768
+ half_dim = embedding_dim // 2
769
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
770
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
771
+ emb = w.to(dtype)[:, None] * emb[None, :]
772
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
773
+ if embedding_dim % 2 == 1: # zero pad
774
+ emb = torch.nn.functional.pad(emb, (0, 1))
775
+ assert emb.shape == (w.shape[0], embedding_dim)
776
+ return emb
777
+
778
+ @property
779
+ def guidance_scale(self):
780
+ return self._guidance_scale
781
+
782
+ @property
783
+ def guidance_rescale(self):
784
+ return self._guidance_rescale
785
+
786
+ @property
787
+ def clip_skip(self):
788
+ return self._clip_skip
789
+
790
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
791
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
792
+ # corresponds to doing no classifier free guidance.
793
+ @property
794
+ def do_classifier_free_guidance(self):
795
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
796
+
797
+ @property
798
+ def cross_attention_kwargs(self):
799
+ return self._cross_attention_kwargs
800
+
801
+ @property
802
+ def denoising_end(self):
803
+ return self._denoising_end
804
+
805
+ @property
806
+ def num_timesteps(self):
807
+ return self._num_timesteps
808
+
809
+ @property
810
+ def interrupt(self):
811
+ return self._interrupt
812
+
813
+ @torch.no_grad()
814
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
815
+ def __call__(
816
+ self,
817
+ prompt: Union[str, List[str]] = None,
818
+ prompt_2: Optional[Union[str, List[str]]] = None,
819
+ height: Optional[int] = None,
820
+ width: Optional[int] = None,
821
+ num_inference_steps: int = 50,
822
+ timesteps: List[int] = None,
823
+ denoising_end: Optional[float] = None,
824
+ guidance_scale: float = 5.0,
825
+ negative_prompt: Optional[Union[str, List[str]]] = None,
826
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
827
+ num_images_per_prompt: Optional[int] = 1,
828
+ eta: float = 0.0,
829
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
830
+ latents: Optional[torch.FloatTensor] = None,
831
+ prompt_embeds: Optional[torch.FloatTensor] = None,
832
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
833
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
834
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
835
+ ip_adapter_image: Optional[PipelineImageInput] = None,
836
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
837
+ output_type: Optional[str] = "pil",
838
+ return_dict: bool = True,
839
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
840
+ guidance_rescale: float = 0.0,
841
+ original_size: Optional[Tuple[int, int]] = None,
842
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
843
+ target_size: Optional[Tuple[int, int]] = None,
844
+ negative_original_size: Optional[Tuple[int, int]] = None,
845
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
846
+ negative_target_size: Optional[Tuple[int, int]] = None,
847
+ clip_skip: Optional[int] = None,
848
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
849
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
850
+ **kwargs,
851
+ ):
852
+ r"""
853
+ Function invoked when calling the pipeline for generation.
854
+
855
+ Args:
856
+ prompt (`str` or `List[str]`, *optional*):
857
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
858
+ instead.
859
+ prompt_2 (`str` or `List[str]`, *optional*):
860
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
861
+ used in both text-encoders
862
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
863
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
864
+ Anything below 512 pixels won't work well for
865
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
866
+ and checkpoints that are not specifically fine-tuned on low resolutions.
867
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
868
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
869
+ Anything below 512 pixels won't work well for
870
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
871
+ and checkpoints that are not specifically fine-tuned on low resolutions.
872
+ num_inference_steps (`int`, *optional*, defaults to 50):
873
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
874
+ expense of slower inference.
875
+ timesteps (`List[int]`, *optional*):
876
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
877
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
878
+ passed will be used. Must be in descending order.
879
+ denoising_end (`float`, *optional*):
880
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
881
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
882
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
883
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
884
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
885
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
886
+ guidance_scale (`float`, *optional*, defaults to 5.0):
887
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
888
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
889
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
890
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
891
+ usually at the expense of lower image quality.
892
+ negative_prompt (`str` or `List[str]`, *optional*):
893
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
894
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
895
+ less than `1`).
896
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
897
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
898
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
899
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
900
+ The number of images to generate per prompt.
901
+ eta (`float`, *optional*, defaults to 0.0):
902
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
903
+ [`schedulers.DDIMScheduler`], will be ignored for others.
904
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
905
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
906
+ to make generation deterministic.
907
+ latents (`torch.FloatTensor`, *optional*):
908
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
909
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
910
+ tensor will be generated by sampling using the supplied random `generator`.
911
+ prompt_embeds (`torch.FloatTensor`, *optional*):
912
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
913
+ provided, text embeddings will be generated from `prompt` input argument.
914
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
915
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
916
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
917
+ argument.
918
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
919
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
920
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
921
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
922
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
923
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
924
+ input argument.
925
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
926
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
927
+ Pre-generated image embeddings for IP-Adapter. It should be a list with length equal to the number of
928
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
929
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
930
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
931
+ output_type (`str`, *optional*, defaults to `"pil"`):
932
+ The output format of the generated image. Choose between
933
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
934
+ return_dict (`bool`, *optional*, defaults to `True`):
935
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
936
+ of a plain tuple.
937
+ cross_attention_kwargs (`dict`, *optional*):
938
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
939
+ `self.processor` in
940
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
941
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
942
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
943
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
944
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
945
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
946
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
947
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
948
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
949
+ explained in section 2.2 of
950
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
951
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
952
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
953
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
954
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
955
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
956
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
957
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
958
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
959
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
960
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
961
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
962
+ micro-conditioning as explained in section 2.2 of
963
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
964
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
965
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
966
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
967
+ micro-conditioning as explained in section 2.2 of
968
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
969
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
970
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
971
+ To negatively condition the generation process based on a target image resolution. It should be the same
972
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
973
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
974
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
975
+ callback_on_step_end (`Callable`, *optional*):
976
+ A function called at the end of each denoising step during inference. The function is called
977
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
978
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
979
+ `callback_on_step_end_tensor_inputs`.
980
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
981
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
982
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
983
+ `._callback_tensor_inputs` attribute of your pipeline class.
984
+
985
+ Examples:
986
+
987
+ Returns:
988
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
989
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
990
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
991
+ """
992
+
993
+ callback = kwargs.pop("callback", None)
994
+ callback_steps = kwargs.pop("callback_steps", None)
995
+
996
+ if callback is not None:
997
+ deprecate(
998
+ "callback",
999
+ "1.0.0",
1000
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1001
+ )
1002
+ if callback_steps is not None:
1003
+ deprecate(
1004
+ "callback_steps",
1005
+ "1.0.0",
1006
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1007
+ )
1008
+
1009
+ # 0. Default height and width to unet
1010
+ height = height or self.default_sample_size * self.vae_scale_factor
1011
+ width = width or self.default_sample_size * self.vae_scale_factor
1012
+
1013
+ original_size = original_size or (height, width)
1014
+ target_size = target_size or (height, width)
1015
+
1016
+ # 1. Check inputs. Raise error if not correct
1017
+ self.check_inputs(
1018
+ prompt,
1019
+ prompt_2,
1020
+ height,
1021
+ width,
1022
+ callback_steps,
1023
+ negative_prompt,
1024
+ negative_prompt_2,
1025
+ prompt_embeds,
1026
+ negative_prompt_embeds,
1027
+ pooled_prompt_embeds,
1028
+ negative_pooled_prompt_embeds,
1029
+ ip_adapter_image,
1030
+ ip_adapter_image_embeds,
1031
+ callback_on_step_end_tensor_inputs,
1032
+ )
1033
+
1034
+ self._guidance_scale = guidance_scale
1035
+ self._guidance_rescale = guidance_rescale
1036
+ self._clip_skip = clip_skip
1037
+ self._cross_attention_kwargs = cross_attention_kwargs
1038
+ self._denoising_end = denoising_end
1039
+ self._interrupt = False
1040
+
1041
+ # 2. Define call parameters
1042
+ if prompt is not None and isinstance(prompt, str):
1043
+ batch_size = 1
1044
+ elif prompt is not None and isinstance(prompt, list):
1045
+ batch_size = len(prompt)
1046
+ else:
1047
+ batch_size = prompt_embeds.shape[0]
1048
+
1049
+ device = self._execution_device
1050
+
1051
+ # 3. Encode input prompt
1052
+ lora_scale = (
1053
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1054
+ )
1055
+
1056
+ (
1057
+ prompt_embeds,
1058
+ negative_prompt_embeds,
1059
+ pooled_prompt_embeds,
1060
+ negative_pooled_prompt_embeds,
1061
+ ) = self.encode_prompt(
1062
+ prompt=prompt,
1063
+ prompt_2=prompt_2,
1064
+ device=device,
1065
+ num_images_per_prompt=num_images_per_prompt,
1066
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1067
+ negative_prompt=negative_prompt,
1068
+ negative_prompt_2=negative_prompt_2,
1069
+ prompt_embeds=prompt_embeds,
1070
+ negative_prompt_embeds=negative_prompt_embeds,
1071
+ pooled_prompt_embeds=pooled_prompt_embeds,
1072
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1073
+ lora_scale=lora_scale,
1074
+ clip_skip=self.clip_skip,
1075
+ )
1076
+
1077
+ # 4. Prepare timesteps
1078
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
1079
+
1080
+ # 5. Prepare latent variables
1081
+ num_channels_latents = self.unet.config.in_channels
1082
+ latents = self.prepare_latents(
1083
+ batch_size * num_images_per_prompt,
1084
+ num_channels_latents,
1085
+ height,
1086
+ width,
1087
+ prompt_embeds.dtype,
1088
+ device,
1089
+ generator,
1090
+ latents,
1091
+ )
1092
+
1093
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1094
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1095
+
1096
+ # 7. Prepare added time ids & embeddings
1097
+ add_text_embeds = pooled_prompt_embeds
1098
+ if self.text_encoder_2 is None:
1099
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1100
+ else:
1101
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1102
+
1103
+ add_time_ids = self._get_add_time_ids(
1104
+ original_size,
1105
+ crops_coords_top_left,
1106
+ target_size,
1107
+ dtype=prompt_embeds.dtype,
1108
+ text_encoder_projection_dim=text_encoder_projection_dim,
1109
+ )
1110
+ if negative_original_size is not None and negative_target_size is not None:
1111
+ negative_add_time_ids = self._get_add_time_ids(
1112
+ negative_original_size,
1113
+ negative_crops_coords_top_left,
1114
+ negative_target_size,
1115
+ dtype=prompt_embeds.dtype,
1116
+ text_encoder_projection_dim=text_encoder_projection_dim,
1117
+ )
1118
+ else:
1119
+ negative_add_time_ids = add_time_ids
1120
+
1121
+ if self.do_classifier_free_guidance:
1122
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1123
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1124
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1125
+
1126
+ prompt_embeds = prompt_embeds.to(device)
1127
+ add_text_embeds = add_text_embeds.to(device)
1128
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1129
+
1130
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1131
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1132
+ ip_adapter_image,
1133
+ ip_adapter_image_embeds,
1134
+ device,
1135
+ batch_size * num_images_per_prompt,
1136
+ self.do_classifier_free_guidance,
1137
+ )
1138
+
1139
+ # 8. Denoising loop
1140
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1141
+
1142
+ # 8.1 Apply denoising_end
1143
+ if (
1144
+ self.denoising_end is not None
1145
+ and isinstance(self.denoising_end, float)
1146
+ and self.denoising_end > 0
1147
+ and self.denoising_end < 1
1148
+ ):
1149
+ discrete_timestep_cutoff = int(
1150
+ round(
1151
+ self.scheduler.config.num_train_timesteps
1152
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1153
+ )
1154
+ )
1155
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1156
+ timesteps = timesteps[:num_inference_steps]
1157
+
1158
+ # 9. Optionally get Guidance Scale Embedding
1159
+ timestep_cond = None
1160
+ if self.unet.config.time_cond_proj_dim is not None:
1161
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1162
+ timestep_cond = self.get_guidance_scale_embedding(
1163
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1164
+ ).to(device=device, dtype=latents.dtype)
1165
+
1166
+ self._num_timesteps = len(timesteps)
1167
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1168
+ for i, t in enumerate(timesteps):
1169
+ if self.interrupt:
1170
+ continue
1171
+
1172
+ # expand the latents if we are doing classifier free guidance
1173
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1174
+
1175
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1176
+
1177
+ # predict the noise residual
1178
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1179
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1180
+ added_cond_kwargs["image_embeds"] = image_embeds
1181
+
1182
+ noise_pred = self.unet(
1183
+ latent_model_input,
1184
+ t,
1185
+ encoder_hidden_states=prompt_embeds, # [B, 77, 2048]
1186
+ timestep_cond=timestep_cond, # None
1187
+ cross_attention_kwargs=self.cross_attention_kwargs, # None
1188
+ added_cond_kwargs=added_cond_kwargs, # {[B, 1280], [B, 6]}
1189
+ return_dict=False,
1190
+ )[0]
1191
+
1192
+ # perform guidance
1193
+ if self.do_classifier_free_guidance:
1194
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1195
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1196
+
1197
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1198
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1199
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1200
+
1201
+ # compute the previous noisy sample x_t -> x_t-1
1202
+ latents_dtype = latents.dtype
1203
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1204
+ if latents.dtype != latents_dtype:
1205
+ if torch.backends.mps.is_available():
1206
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1207
+ latents = latents.to(latents_dtype)
1208
+
1209
+ if callback_on_step_end is not None:
1210
+ callback_kwargs = {}
1211
+ for k in callback_on_step_end_tensor_inputs:
1212
+ callback_kwargs[k] = locals()[k]
1213
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1214
+
1215
+ latents = callback_outputs.pop("latents", latents)
1216
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1217
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1218
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1219
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1220
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1221
+ )
1222
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1223
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1224
+
1225
+ # call the callback, if provided
1226
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1227
+ progress_bar.update()
1228
+ if callback is not None and i % callback_steps == 0:
1229
+ step_idx = i // getattr(self.scheduler, "order", 1)
1230
+ callback(step_idx, t, latents)
1231
+
1232
+ if XLA_AVAILABLE:
1233
+ xm.mark_step()
1234
+
1235
+ if not output_type == "latent":
1236
+ # make sure the VAE is in float32 mode, as it overflows in float16
1237
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1238
+
1239
+ if needs_upcasting:
1240
+ self.upcast_vae()
1241
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1242
+ elif latents.dtype != self.vae.dtype:
1243
+ if torch.backends.mps.is_available():
1244
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1245
+ self.vae = self.vae.to(latents.dtype)
1246
+
1247
+ # unscale/denormalize the latents
1248
+ # denormalize with the mean and std if available and not None
1249
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1250
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1251
+ if has_latents_mean and has_latents_std:
1252
+ latents_mean = (
1253
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1254
+ )
1255
+ latents_std = (
1256
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1257
+ )
1258
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1259
+ else:
1260
+ latents = latents / self.vae.config.scaling_factor
1261
+
1262
+ image = self.vae.decode(latents, return_dict=False)[0]
1263
+
1264
+ # cast back to fp16 if needed
1265
+ if needs_upcasting:
1266
+ self.vae.to(dtype=torch.float16)
1267
+ else:
1268
+ image = latents
1269
+
1270
+ if not output_type == "latent":
1271
+ # apply watermark if available
1272
+ if self.watermark is not None:
1273
+ image = self.watermark.apply_watermark(image)
1274
+
1275
+ image = self.image_processor.postprocess(image, output_type=output_type)
1276
+
1277
+ # Offload all models
1278
+ self.maybe_free_model_hooks()
1279
+
1280
+ if not return_dict:
1281
+ return (image,)
1282
+
1283
+ return StableDiffusionXLPipelineOutput(images=image)
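
Note: the guidance-embedding path above (used only when the UNet exposes `time_cond_proj_dim`, e.g. LCM-style UNets) can be sanity-checked in isolation. The snippet below is a minimal sketch that reproduces the same sinusoidal construction outside the pipeline; the embedding dimension of 256 is illustrative, in practice it comes from `unet.config.time_cond_proj_dim`.

import torch

def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
    # Same construction as get_guidance_scale_embedding: a sinusoidal embedding of w * 1000.
    assert w.ndim == 1
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = w.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad to the requested dimension
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb  # shape (len(w), embedding_dim)

# The pipeline feeds (guidance_scale - 1) per sample, e.g. for a batch of 4 at guidance_scale=5.0:
timestep_cond = guidance_scale_embedding(torch.full((4,), 5.0 - 1.0), embedding_dim=256)
print(timestep_cond.shape)  # torch.Size([4, 256])
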
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ diffusers==0.28.1
2
+ accelerate==0.25.0
3
+ datasets==2.19.1
4
+ einops==0.8.0
5
+ kornia==0.7.2
6
+ numpy==1.26.4
7
+ opencv-python==4.9.0.80
8
+ peft==0.10.0
9
+ pyrallis==0.3.1
10
+ tokenizers==0.15.2
11
+ torch==2.0.1
12
+ torchvision==0.15.2
13
+ transformers==4.36.2
14
+ gradio==4.44.1
schedulers/lcm_single_step_scheduler.py ADDED
@@ -0,0 +1,537 @@
1
+ # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
16
+ # and https://github.com/hojonathanho/diffusion
17
+
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+ from diffusers.utils import BaseOutput, logging
27
+ from diffusers.utils.torch_utils import randn_tensor
28
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
+
30
+
31
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
+
33
+
34
+ @dataclass
35
+ class LCMSingleStepSchedulerOutput(BaseOutput):
36
+ """
37
+ Output class for the scheduler's `step` function output.
38
+
39
+ Args:
40
+ denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
41
+ The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
42
+ `denoised` can be used to preview progress or for guidance.
43
+ """
44
+
45
+ denoised: Optional[torch.FloatTensor] = None
46
+
47
+
48
+ # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
49
+ def betas_for_alpha_bar(
50
+ num_diffusion_timesteps,
51
+ max_beta=0.999,
52
+ alpha_transform_type="cosine",
53
+ ):
54
+ """
55
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
56
+ (1-beta) over time from t = [0,1].
57
+
58
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
59
+ to that part of the diffusion process.
60
+
61
+
62
+ Args:
63
+ num_diffusion_timesteps (`int`): the number of betas to produce.
64
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
65
+ prevent singularities.
66
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
67
+ Choose from `cosine` or `exp`
68
+
69
+ Returns:
70
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
71
+ """
72
+ if alpha_transform_type == "cosine":
73
+
74
+ def alpha_bar_fn(t):
75
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
76
+
77
+ elif alpha_transform_type == "exp":
78
+
79
+ def alpha_bar_fn(t):
80
+ return math.exp(t * -12.0)
81
+
82
+ else:
83
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
84
+
85
+ betas = []
86
+ for i in range(num_diffusion_timesteps):
87
+ t1 = i / num_diffusion_timesteps
88
+ t2 = (i + 1) / num_diffusion_timesteps
89
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
90
+ return torch.tensor(betas, dtype=torch.float32)
91
+
92
+
93
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
94
+ def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
95
+ """
96
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
97
+
98
+
99
+ Args:
100
+ betas (`torch.FloatTensor`):
101
+ the betas that the scheduler is being initialized with.
102
+
103
+ Returns:
104
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
105
+ """
106
+ # Convert betas to alphas_bar_sqrt
107
+ alphas = 1.0 - betas
108
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
109
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
110
+
111
+ # Store old values.
112
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
113
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
114
+
115
+ # Shift so the last timestep is zero.
116
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
117
+
118
+ # Scale so the first timestep is back to the old value.
119
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
120
+
121
+ # Convert alphas_bar_sqrt to betas
122
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
123
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
124
+ alphas = torch.cat([alphas_bar[0:1], alphas])
125
+ betas = 1 - alphas
126
+
127
+ return betas
128
+
129
+
130
+ class LCMSingleStepScheduler(SchedulerMixin, ConfigMixin):
131
+ """
132
+ `LCMSingleStepScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
133
+ non-Markovian guidance.
134
+
135
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
136
+ attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
137
+ accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
138
+ functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.
139
+
140
+ Args:
141
+ num_train_timesteps (`int`, defaults to 1000):
142
+ The number of diffusion steps to train the model.
143
+ beta_start (`float`, defaults to 0.00085):
144
+ The starting `beta` value of inference.
145
+ beta_end (`float`, defaults to 0.012):
146
+ The final `beta` value.
147
+ beta_schedule (`str`, defaults to `"scaled_linear"`):
148
+ The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
149
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
150
+ trained_betas (`np.ndarray`, *optional*):
151
+ Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
152
+ original_inference_steps (`int`, *optional*, defaults to 50):
153
+ The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
154
+ will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
155
+ clip_sample (`bool`, defaults to `False`):
156
+ Clip the predicted sample for numerical stability.
157
+ clip_sample_range (`float`, defaults to 1.0):
158
+ The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
159
+ set_alpha_to_one (`bool`, defaults to `True`):
160
+ Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
161
+ there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
162
+ otherwise it uses the alpha value at step 0.
163
+ steps_offset (`int`, defaults to 0):
164
+ An offset added to the inference steps. You can use a combination of `offset=1` and
165
+ `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
166
+ Diffusion.
167
+ prediction_type (`str`, defaults to `epsilon`, *optional*):
168
+ Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
169
+ `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
170
+ Video](https://imagen.research.google/video/paper.pdf) paper).
171
+ thresholding (`bool`, defaults to `False`):
172
+ Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
173
+ as Stable Diffusion.
174
+ dynamic_thresholding_ratio (`float`, defaults to 0.995):
175
+ The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
176
+ sample_max_value (`float`, defaults to 1.0):
177
+ The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
178
+ timestep_spacing (`str`, defaults to `"leading"`):
179
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
180
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
181
+ timestep_scaling (`float`, defaults to 10.0):
182
+ The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
183
+ `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
184
+ error at the default of `10.0` is already pretty small).
185
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
186
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
187
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
188
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
189
+ """
190
+
191
+ order = 1
192
+
193
+ @register_to_config
194
+ def __init__(
195
+ self,
196
+ num_train_timesteps: int = 1000,
197
+ beta_start: float = 0.00085,
198
+ beta_end: float = 0.012,
199
+ beta_schedule: str = "scaled_linear",
200
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
201
+ original_inference_steps: int = 50,
202
+ clip_sample: bool = False,
203
+ clip_sample_range: float = 1.0,
204
+ set_alpha_to_one: bool = True,
205
+ steps_offset: int = 0,
206
+ prediction_type: str = "epsilon",
207
+ thresholding: bool = False,
208
+ dynamic_thresholding_ratio: float = 0.995,
209
+ sample_max_value: float = 1.0,
210
+ timestep_spacing: str = "leading",
211
+ timestep_scaling: float = 10.0,
212
+ rescale_betas_zero_snr: bool = False,
213
+ ):
214
+ if trained_betas is not None:
215
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
216
+ elif beta_schedule == "linear":
217
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
218
+ elif beta_schedule == "scaled_linear":
219
+ # this schedule is very specific to the latent diffusion model.
220
+ self.betas = (
221
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
222
+ )
223
+ elif beta_schedule == "squaredcos_cap_v2":
224
+ # Glide cosine schedule
225
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
226
+ else:
227
+ raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
228
+
229
+ # Rescale for zero SNR
230
+ if rescale_betas_zero_snr:
231
+ self.betas = rescale_zero_terminal_snr(self.betas)
232
+
233
+ self.alphas = 1.0 - self.betas
234
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
235
+
236
+ # At every step in ddim, we are looking into the previous alphas_cumprod
237
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
238
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
239
+ # whether we use the final alpha of the "non-previous" one.
240
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
241
+
242
+ # standard deviation of the initial noise distribution
243
+ self.init_noise_sigma = 1.0
244
+
245
+ # setable values
246
+ self.num_inference_steps = None
247
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
248
+
249
+ self._step_index = None
250
+
251
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
252
+ def _init_step_index(self, timestep):
253
+ if isinstance(timestep, torch.Tensor):
254
+ timestep = timestep.to(self.timesteps.device)
255
+
256
+ index_candidates = (self.timesteps == timestep).nonzero()
257
+
258
+ # The sigma index that is taken for the **very** first `step`
259
+ # is always the second index (or the last index if there is only 1)
260
+ # This way we can ensure we don't accidentally skip a sigma in
261
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
262
+ if len(index_candidates) > 1:
263
+ step_index = index_candidates[1]
264
+ else:
265
+ step_index = index_candidates[0]
266
+
267
+ self._step_index = step_index.item()
268
+
269
+ @property
270
+ def step_index(self):
271
+ return self._step_index
272
+
273
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
274
+ """
275
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
276
+ current timestep.
277
+
278
+ Args:
279
+ sample (`torch.FloatTensor`):
280
+ The input sample.
281
+ timestep (`int`, *optional*):
282
+ The current timestep in the diffusion chain.
283
+ Returns:
284
+ `torch.FloatTensor`:
285
+ A scaled input sample.
286
+ """
287
+ return sample
288
+
289
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
290
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
291
+ """
292
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
293
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
294
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
295
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
296
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
297
+
298
+ https://arxiv.org/abs/2205.11487
299
+ """
300
+ dtype = sample.dtype
301
+ batch_size, channels, *remaining_dims = sample.shape
302
+
303
+ if dtype not in (torch.float32, torch.float64):
304
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
305
+
306
+ # Flatten sample for doing quantile calculation along each image
307
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
308
+
309
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
310
+
311
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
312
+ s = torch.clamp(
313
+ s, min=1, max=self.config.sample_max_value
314
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
315
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
316
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
317
+
318
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
319
+ sample = sample.to(dtype)
320
+
321
+ return sample
322
+
323
+ def set_timesteps(
324
+ self,
325
+ num_inference_steps: int = None,
326
+ device: Union[str, torch.device] = None,
327
+ original_inference_steps: Optional[int] = None,
328
+ strength: int = 1.0,
329
+ timesteps: Optional[list] = None,
330
+ ):
331
+ """
332
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
333
+
334
+ Args:
335
+ num_inference_steps (`int`):
336
+ The number of diffusion steps used when generating samples with a pre-trained model.
337
+ device (`str` or `torch.device`, *optional*):
338
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
339
+ original_inference_steps (`int`, *optional*):
340
+ The original number of inference steps, which will be used to generate a linearly-spaced timestep
341
+ schedule (which is different from the standard `diffusers` implementation). We will then take
342
+ `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
343
+ our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
344
+ """
345
+
346
+ if num_inference_steps is not None and timesteps is not None:
347
+ raise ValueError("Can only pass one of `num_inference_steps` or `custom_timesteps`.")
348
+
349
+ if timesteps is not None:
350
+ for i in range(1, len(timesteps)):
351
+ if timesteps[i] >= timesteps[i - 1]:
352
+ raise ValueError("`custom_timesteps` must be in descending order.")
353
+
354
+ if timesteps[0] >= self.config.num_train_timesteps:
355
+ raise ValueError(
356
+ f"`timesteps` must start before `self.config.train_timesteps`:"
357
+ f" {self.config.num_train_timesteps}."
358
+ )
359
+
360
+ timesteps = np.array(timesteps, dtype=np.int64)
361
+ else:
362
+ if num_inference_steps > self.config.num_train_timesteps:
363
+ raise ValueError(
364
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
365
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
366
+ f" maximal {self.config.num_train_timesteps} timesteps."
367
+ )
368
+
369
+ self.num_inference_steps = num_inference_steps
370
+ original_steps = (
371
+ original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
372
+ )
373
+
374
+ if original_steps > self.config.num_train_timesteps:
375
+ raise ValueError(
376
+ f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
377
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
378
+ f" maximal {self.config.num_train_timesteps} timesteps."
379
+ )
380
+
381
+ if num_inference_steps > original_steps:
382
+ raise ValueError(
383
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
384
+ f" {original_steps} because the final timestep schedule will be a subset of the"
385
+ f" `original_inference_steps`-sized initial timestep schedule."
386
+ )
387
+
388
+ # LCM Timesteps Setting
389
+ # Currently, only linear spacing is supported.
390
+ c = self.config.num_train_timesteps // original_steps
391
+ # LCM Training Steps Schedule
392
+ lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
393
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
394
+ # LCM Inference Steps Schedule
395
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
396
+
397
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device=device, dtype=torch.long)
398
+
399
+ self._step_index = None
400
+
401
+ def get_scalings_for_boundary_condition_discrete(self, timestep):
402
+ self.sigma_data = 0.5 # Default: 0.5
403
+ scaled_timestep = timestep * self.config.timestep_scaling
404
+
405
+ c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
406
+ c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
407
+ return c_skip, c_out
408
+
409
+ def append_dims(self, x, target_dims):
410
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
411
+ dims_to_append = target_dims - x.ndim
412
+ if dims_to_append < 0:
413
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
414
+ return x[(...,) + (None,) * dims_to_append]
415
+
416
+ def extract_into_tensor(self, a, t, x_shape):
417
+ b, *_ = t.shape
418
+ out = a.gather(-1, t)
419
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
420
+
421
+ def step(
422
+ self,
423
+ model_output: torch.FloatTensor,
424
+ timestep: torch.Tensor,
425
+ sample: torch.FloatTensor,
426
+ generator: Optional[torch.Generator] = None,
427
+ return_dict: bool = True,
428
+ ) -> Union[LCMSingleStepSchedulerOutput, Tuple]:
429
+ """
430
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
431
+ process from the learned model outputs (most often the predicted noise).
432
+
433
+ Args:
434
+ model_output (`torch.FloatTensor`):
435
+ The direct output from learned diffusion model.
436
+ timestep (`float`):
437
+ The current discrete timestep in the diffusion chain.
438
+ sample (`torch.FloatTensor`):
439
+ A current instance of a sample created by the diffusion process.
440
+ generator (`torch.Generator`, *optional*):
441
+ A random number generator.
442
+ return_dict (`bool`, *optional*, defaults to `True`):
443
+ Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
444
+ Returns:
445
+ [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`:
446
+ If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
447
+ tuple is returned where the first element is the sample tensor.
448
+ """
449
+ # 0. make sure everything is on the same device
450
+ alphas_cumprod = self.alphas_cumprod.to(sample.device)
451
+
452
+ # 1. compute alphas, betas
453
+ if timestep.ndim == 0:
454
+ timestep = timestep.unsqueeze(0)
455
+ alpha_prod_t = self.extract_into_tensor(alphas_cumprod, timestep, sample.shape)
456
+ beta_prod_t = 1 - alpha_prod_t
457
+
458
+ # 2. Get scalings for boundary conditions
459
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
460
+ c_skip, c_out = [self.append_dims(x, sample.ndim) for x in [c_skip, c_out]]
461
+
462
+ # 3. Compute the predicted original sample x_0 based on the model parameterization
463
+ if self.config.prediction_type == "epsilon": # noise-prediction
464
+ predicted_original_sample = (sample - torch.sqrt(beta_prod_t) * model_output) / torch.sqrt(alpha_prod_t)
465
+ elif self.config.prediction_type == "sample": # x-prediction
466
+ predicted_original_sample = model_output
467
+ elif self.config.prediction_type == "v_prediction": # v-prediction
468
+ predicted_original_sample = torch.sqrt(alpha_prod_t) * sample - torch.sqrt(beta_prod_t) * model_output
469
+ else:
470
+ raise ValueError(
471
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
472
+ " `v_prediction` for `LCMScheduler`."
473
+ )
474
+
475
+ # 4. Clip or threshold "predicted x_0"
476
+ if self.config.thresholding:
477
+ predicted_original_sample = self._threshold_sample(predicted_original_sample)
478
+ elif self.config.clip_sample:
479
+ predicted_original_sample = predicted_original_sample.clamp(
480
+ -self.config.clip_sample_range, self.config.clip_sample_range
481
+ )
482
+
483
+ # 5. Denoise model output using boundary conditions
484
+ denoised = c_out * predicted_original_sample + c_skip * sample
485
+
486
+ if not return_dict:
487
+ return (denoised, )
488
+
489
+ return LCMSingleStepSchedulerOutput(denoised=denoised)
490
+
491
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
492
+ def add_noise(
493
+ self,
494
+ original_samples: torch.FloatTensor,
495
+ noise: torch.FloatTensor,
496
+ timesteps: torch.IntTensor,
497
+ ) -> torch.FloatTensor:
498
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
499
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
500
+ timesteps = timesteps.to(original_samples.device)
501
+
502
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
503
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
504
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
505
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
506
+
507
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
508
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
509
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
510
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
511
+
512
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
513
+ return noisy_samples
514
+
515
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
516
+ def get_velocity(
517
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
518
+ ) -> torch.FloatTensor:
519
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
520
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
521
+ timesteps = timesteps.to(sample.device)
522
+
523
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
524
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
525
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
526
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
527
+
528
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
529
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
530
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
531
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
532
+
533
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
534
+ return velocity
535
+
536
+ def __len__(self):
537
+ return self.config.num_train_timesteps
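
For reference, a minimal usage sketch of `LCMSingleStepScheduler` (assuming the repository root is on the Python path; the latent shape and step count are illustrative). It shows the scheduler's one-step behaviour: a single `step` call turns a noise prediction into a denoised preview of the clean latent via `out.denoised`.

import torch
from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler

scheduler = LCMSingleStepScheduler()  # defaults: 1000 train steps, scaled_linear betas, epsilon prediction
scheduler.set_timesteps(num_inference_steps=4, original_inference_steps=50)

sample = torch.randn(1, 4, 128, 128)           # a noisy latent (shape is illustrative)
t = scheduler.timesteps[0]                     # the noisiest timestep of the 4-step schedule
model_output = torch.randn_like(sample)        # stand-in for a UNet epsilon prediction

out = scheduler.step(model_output, t, sample)  # single-step consistency update
preview = out.denoised                         # predicted clean latent, same shape as `sample`
print(preview.shape)  # torch.Size([1, 4, 128, 128])
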
train_previewer_lora.py ADDED
@@ -0,0 +1,1712 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The LCM team and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
+
16
+ import argparse
17
+ import copy
18
+ import functools
19
+ import gc
20
+ import logging
21
+ import pyrallis
22
+ import math
23
+ import os
24
+ import random
25
+ import shutil
26
+ from contextlib import nullcontext
27
+ from pathlib import Path
28
+
29
+ import accelerate
30
+ import numpy as np
31
+ import torch
32
+ import torch.nn.functional as F
33
+ import torch.utils.checkpoint
34
+ import transformers
35
+ from PIL import Image
36
+ from accelerate import Accelerator
37
+ from accelerate.logging import get_logger
38
+ from accelerate.utils import ProjectConfiguration, set_seed
39
+ from datasets import load_dataset
40
+ from huggingface_hub import create_repo, upload_folder
41
+ from packaging import version
42
+ from collections import namedtuple
43
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
44
+ from torchvision import transforms
45
+ from torchvision.transforms.functional import crop
46
+ from tqdm.auto import tqdm
47
+ from transformers import (
48
+ AutoTokenizer,
49
+ PretrainedConfig,
50
+ CLIPImageProcessor, CLIPVisionModelWithProjection,
51
+ AutoImageProcessor, AutoModel
52
+ )
53
+
54
+ import diffusers
55
+ from diffusers import (
56
+ AutoencoderKL,
57
+ DDPMScheduler,
58
+ LCMScheduler,
59
+ StableDiffusionXLPipeline,
60
+ UNet2DConditionModel,
61
+ )
62
+ from diffusers.optimization import get_scheduler
63
+ from diffusers.training_utils import cast_training_params, resolve_interpolation_mode
64
+ from diffusers.utils import (
65
+ check_min_version,
66
+ convert_state_dict_to_diffusers,
67
+ convert_unet_state_dict_to_peft,
68
+ is_wandb_available,
69
+ )
70
+ from diffusers.utils.import_utils import is_xformers_available
71
+ from diffusers.utils.torch_utils import is_compiled_module
72
+
73
+ from basicsr.utils.degradation_pipeline import RealESRGANDegradation
74
+ from utils.train_utils import (
75
+ seperate_ip_params_from_unet,
76
+ import_model_class_from_model_name_or_path,
77
+ tensor_to_pil,
78
+ get_train_dataset, prepare_train_dataset, collate_fn,
79
+ encode_prompt, importance_sampling_fn, extract_into_tensor
80
+
81
+ )
82
+ from data.data_config import DataConfig
83
+ from losses.loss_config import LossesConfig
84
+ from losses.losses import *
85
+
86
+ from module.ip_adapter.resampler import Resampler
87
+ from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
88
+
89
+
90
+ if is_wandb_available():
91
+ import wandb
92
+
93
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
94
+
95
+ logger = get_logger(__name__)
96
+
97
+
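+ # Encode the low-quality (LQ) validation images into VAE latents, scale them, and noise them
+ # to the requested timestep with `scheduler.add_noise`, so sampling can start from an
+ # LQ-informed latent instead of pure Gaussian noise.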
98
+ def prepare_latents(lq, vae, scheduler, generator, timestep):
99
+ transform = transforms.Compose([
100
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
101
+ transforms.CenterCrop(args.resolution),
102
+ transforms.ToTensor(),
103
+ ])
104
+ lq_pt = [transform(lq_pil.convert("RGB")) for lq_pil in lq]
105
+ img_pt = torch.stack(lq_pt).to(vae.device, dtype=vae.dtype)
106
+ img_pt = img_pt * 2.0 - 1.0
107
+ with torch.no_grad():
108
+ latents = vae.encode(img_pt).latent_dist.sample()
109
+ latents = latents * vae.config.scaling_factor
110
+ noise = torch.randn(latents.shape, generator=generator, device=vae.device, dtype=vae.dtype, layout=torch.strided).to(vae.device)
111
+ bsz = latents.shape[0]
112
+ print(f"init latent at {timestep}")
113
+ timestep = torch.tensor([timestep]*bsz, device=vae.device, dtype=torch.int64)
114
+ latents = scheduler.add_noise(latents, noise, timestep)
115
+ return latents
116
+
117
+
118
+ def log_validation(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
119
+ scheduler, image_encoder, image_processor,
120
+ args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
121
+ logger.info("Running validation... ")
122
+
123
+ image_logs = []
124
+
125
+ lq = [Image.open(lq_example) for lq_example in args.validation_image]
126
+
127
+ pipe = StableDiffusionXLPipeline(
128
+ vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
129
+ unet, scheduler, image_encoder, image_processor,
130
+ ).to(accelerator.device)
131
+
132
+ timesteps = [args.num_train_timesteps - 1]
133
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
134
+ latents = prepare_latents(lq, vae, scheduler, generator, timesteps[-1])
135
+ image = pipe(
136
+ prompt=[""]*len(lq),
137
+ ip_adapter_image=[lq],
138
+ num_inference_steps=1,
139
+ timesteps=timesteps,
140
+ generator=generator,
141
+ guidance_scale=1.0,
142
+ height=args.resolution,
143
+ width=args.resolution,
144
+ latents=latents,
145
+ ).images
146
+
147
+ if log_local:
148
+ # for i, img in enumerate(tensor_to_pil(lq_img)):
149
+ # img.save(f"./lq_{i}.png")
150
+ # for i, img in enumerate(tensor_to_pil(gt_img)):
151
+ # img.save(f"./gt_{i}.png")
152
+ for i, img in enumerate(image):
153
+ img.save(f"./lq_IPA_{i}.png")
154
+ return
155
+
156
+ tracker_key = "test" if is_final_validation else "validation"
157
+ for tracker in accelerator.trackers:
158
+ if tracker.name == "tensorboard":
159
+ images = [np.asarray(pil_img) for pil_img in image]
160
+ images = np.stack(images, axis=0)
161
+ if lq_img is not None and gt_img is not None:
162
+ input_lq = lq_img.detach().cpu()
163
+ input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
164
+ input_gt = gt_img.detach().cpu()
165
+ input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
166
+ tracker.writer.add_images("lq", input_lq, step, dataformats="NCHW")
167
+ tracker.writer.add_images("gt", input_gt, step, dataformats="NCHW")
168
+ tracker.writer.add_images("rec", images, step, dataformats="NHWC")
169
+ elif tracker.name == "wandb":
170
+ raise NotImplementedError("Wandb logging not implemented for validation.")
171
+ formatted_images = []
172
+
173
+ for log in image_logs:
174
+ images = log["images"]
175
+ validation_prompt = log["validation_prompt"]
176
+ validation_image = log["validation_image"]
177
+
178
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
179
+
180
+ for image in images:
181
+ image = wandb.Image(image, caption=validation_prompt)
182
+ formatted_images.append(image)
183
+
184
+ tracker.log({tracker_key: formatted_images})
185
+ else:
186
+ logger.warning(f"image logging not implemented for {tracker.name}")
187
+
188
+ gc.collect()
189
+ torch.cuda.empty_cache()
190
+
191
+ return image_logs
192
+
193
+
194
+ class DDIMSolver:
195
+ def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
196
+ # DDIM sampling parameters
197
+ step_ratio = timesteps // ddim_timesteps
198
+
199
+ self.ddim_timesteps = (np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
200
+ self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
201
+ self.ddim_alpha_cumprods_prev = np.asarray(
202
+ [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
203
+ )
204
+ # convert to torch tensors
205
+ self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
206
+ self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
207
+ self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)
208
+
209
+ def to(self, device):
210
+ self.ddim_timesteps = self.ddim_timesteps.to(device)
211
+ self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
212
+ self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
213
+ return self
214
+
215
+ def ddim_step(self, pred_x0, pred_noise, timestep_index):
216
+ alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
217
+ dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
218
+ x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
219
+ return x_prev
220
+
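For reference, `ddim_step` above is the deterministic DDIM update (\(\eta = 0\)) expressed in terms of the predicted clean sample and predicted noise:

\[
x_{t_{\text{prev}}} = \sqrt{\bar{\alpha}_{t_{\text{prev}}}}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_{t_{\text{prev}}}}\,\hat{\epsilon}.
\]

With the defaults used below (1000 training timesteps, 50 DDIM timesteps), `ddim_timesteps` works out to [19, 39, ..., 999].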
221
+
222
+ def append_dims(x, target_dims):
223
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
224
+ dims_to_append = target_dims - x.ndim
225
+ if dims_to_append < 0:
226
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
227
+ return x[(...,) + (None,) * dims_to_append]
228
+
229
+
230
+ # From LCMScheduler.get_scalings_for_boundary_condition_discrete
231
+ def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
232
+ scaled_timestep = timestep_scaling * timestep
233
+ c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
234
+ c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
235
+ return c_skip, c_out
236
+
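These are the LCM boundary-condition scalings; with timestep scaling \(s\) and \(\sigma_d = 0.5\) they are

\[
c_{\text{skip}}(t) = \frac{\sigma_d^2}{(s t)^2 + \sigma_d^2},
\qquad
c_{\text{out}}(t) = \frac{s t}{\sqrt{(s t)^2 + \sigma_d^2}},
\]

so that \(c_{\text{skip}}(0) = 1\) and \(c_{\text{out}}(0) = 0\), enforcing the consistency-model boundary condition \(f_\theta(x, 0) = x\).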
237
+
238
+ # Compare LCMScheduler.step, Step 4
239
+ def get_predicted_original_sample(model_output, timesteps, sample, prediction_type, alphas, sigmas):
240
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
241
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
242
+ if prediction_type == "epsilon":
243
+ pred_x_0 = (sample - sigmas * model_output) / alphas
244
+ elif prediction_type == "sample":
245
+ pred_x_0 = model_output
246
+ elif prediction_type == "v_prediction":
247
+ pred_x_0 = alphas * sample - sigmas * model_output
248
+ else:
249
+ raise ValueError(
250
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
251
+ f" are supported."
252
+ )
253
+
254
+ return pred_x_0
255
+
256
+
257
+ # Based on step 4 in DDIMScheduler.step
258
+ def get_predicted_noise(model_output, timesteps, sample, prediction_type, alphas, sigmas):
259
+ alphas = extract_into_tensor(alphas, timesteps, sample.shape)
260
+ sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
261
+ if prediction_type == "epsilon":
262
+ pred_epsilon = model_output
263
+ elif prediction_type == "sample":
264
+ pred_epsilon = (sample - alphas * model_output) / sigmas
265
+ elif prediction_type == "v_prediction":
266
+ pred_epsilon = alphas * model_output + sigmas * sample
267
+ else:
268
+ raise ValueError(
269
+ f"Prediction type {prediction_type} is not supported; currently, `epsilon`, `sample`, and `v_prediction`"
270
+ f" are supported."
271
+ )
272
+
273
+ return pred_epsilon
274
+
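The two converters above recover \(\hat{x}_0\) and \(\hat{\epsilon}\) from the model output. With \(\alpha_t = \sqrt{\bar{\alpha}_t}\) (the `alphas` schedule) and \(\sigma_t = \sqrt{1-\bar{\alpha}_t}\) (the `sigmas` schedule), the \(\epsilon\)-prediction and sample-prediction cases reduce to

\[
\hat{x}_0 = \frac{x_t - \sigma_t\,\hat{\epsilon}}{\alpha_t},
\qquad
\hat{\epsilon} = \frac{x_t - \alpha_t\,\hat{x}_0}{\sigma_t},
\]

and for v-prediction, \(\hat{x}_0 = \alpha_t x_t - \sigma_t v\) and \(\hat{\epsilon} = \alpha_t v + \sigma_t x_t\).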
275
+
276
+ def extract_into_tensor(a, t, x_shape):
277
+ b, *_ = t.shape
278
+ out = a.gather(-1, t)
279
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
280
+
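`extract_into_tensor` gathers one schedule value per batch element and reshapes it so it broadcasts over the sample dimensions. A minimal sketch of its behaviour (the tensor names and values below are illustrative only):

import torch

schedule = torch.linspace(0.0, 1.0, steps=1000)    # a per-timestep coefficient table
t = torch.tensor([19, 499, 999])                   # one timestep per batch element
x_shape = (3, 4, 64, 64)                           # shape of the latent batch

# equivalent to extract_into_tensor(schedule, t, x_shape)
coeff = schedule.gather(-1, t).reshape(t.shape[0], *((1,) * (len(x_shape) - 1)))
print(coeff.shape)  # torch.Size([3, 1, 1, 1]) -> broadcasts against (3, 4, 64, 64)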
281
+
282
+ def parse_args():
283
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
284
+ # ----------Model Checkpoint Loading Arguments----------
285
+ parser.add_argument(
286
+ "--pretrained_model_name_or_path",
287
+ type=str,
288
+ default=None,
289
+ required=True,
290
+ help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.",
291
+ )
292
+ parser.add_argument(
293
+ "--pretrained_vae_model_name_or_path",
294
+ type=str,
295
+ default=None,
296
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
297
+ )
298
+ parser.add_argument(
299
+ "--teacher_revision",
300
+ type=str,
301
+ default=None,
302
+ required=False,
303
+ help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.",
304
+ )
305
+ parser.add_argument(
306
+ "--revision",
307
+ type=str,
308
+ default=None,
309
+ required=False,
310
+ help="Revision of pretrained LDM model identifier from huggingface.co/models.",
311
+ )
312
+ parser.add_argument(
313
+ "--pretrained_lcm_lora_path",
314
+ type=str,
315
+ default=None,
316
+ help="Path to LCM lora or model identifier from huggingface.co/models.",
317
+ )
318
+ parser.add_argument(
319
+ "--feature_extractor_path",
320
+ type=str,
321
+ default=None,
322
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
323
+ )
324
+ parser.add_argument(
325
+ "--pretrained_adapter_model_path",
326
+ type=str,
327
+ default=None,
328
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
329
+ )
330
+ parser.add_argument(
331
+ "--adapter_tokens",
332
+ type=int,
333
+ default=64,
334
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
335
+ )
336
+ parser.add_argument(
337
+ "--use_clip_encoder",
338
+ action="store_true",
339
+ help="Whether or not to use DINO as image encoder, else CLIP encoder.",
340
+ )
341
+ parser.add_argument(
342
+ "--image_encoder_hidden_feature",
343
+ action="store_true",
344
+ help="Whether or not to use the penultimate hidden states as image embeddings.",
345
+ )
346
+ # ----------Training Arguments----------
347
+ # ----General Training Arguments----
348
+ parser.add_argument(
349
+ "--output_dir",
350
+ type=str,
351
+ default="lcm-xl-distilled",
352
+ help="The output directory where the model predictions and checkpoints will be written.",
353
+ )
354
+ parser.add_argument(
355
+ "--cache_dir",
356
+ type=str,
357
+ default=None,
358
+ help="The directory where the downloaded models and datasets will be stored.",
359
+ )
360
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
361
+ # ----Logging----
362
+ parser.add_argument(
363
+ "--logging_dir",
364
+ type=str,
365
+ default="logs",
366
+ help=(
367
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
368
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
369
+ ),
370
+ )
371
+ parser.add_argument(
372
+ "--report_to",
373
+ type=str,
374
+ default="tensorboard",
375
+ help=(
376
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
377
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
378
+ ),
379
+ )
380
+ # ----Checkpointing----
381
+ parser.add_argument(
382
+ "--checkpointing_steps",
383
+ type=int,
384
+ default=4000,
385
+ help=(
386
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
387
+ " training using `--resume_from_checkpoint`."
388
+ ),
389
+ )
390
+ parser.add_argument(
391
+ "--checkpoints_total_limit",
392
+ type=int,
393
+ default=5,
394
+ help=("Max number of checkpoints to store."),
395
+ )
396
+ parser.add_argument(
397
+ "--resume_from_checkpoint",
398
+ type=str,
399
+ default=None,
400
+ help=(
401
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
402
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
403
+ ),
404
+ )
405
+ parser.add_argument(
406
+ "--save_only_adapter",
407
+ action="store_true",
408
+ help="Only save extra adapter to save space.",
409
+ )
410
+ # ----Image Processing----
411
+ parser.add_argument(
412
+ "--data_config_path",
413
+ type=str,
414
+ default=None,
415
+ help=("Path to a yaml config file (loaded into a `DataConfig`) describing the training datasets."),
416
+ )
417
+ parser.add_argument(
418
+ "--train_data_dir",
419
+ type=str,
420
+ default=None,
421
+ help=(
422
+ "A folder containing the training data. Folder contents must follow the structure described in"
423
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
424
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
425
+ ),
426
+ )
427
+ parser.add_argument(
428
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
429
+ )
430
+ parser.add_argument(
431
+ "--conditioning_image_column",
432
+ type=str,
433
+ default="conditioning_image",
434
+ help="The column of the dataset containing the controlnet conditioning image.",
435
+ )
436
+ parser.add_argument(
437
+ "--caption_column",
438
+ type=str,
439
+ default="text",
440
+ help="The column of the dataset containing a caption or a list of captions.",
441
+ )
442
+ parser.add_argument(
443
+ "--text_drop_rate",
444
+ type=float,
445
+ default=0,
446
+ help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
447
+ )
448
+ parser.add_argument(
449
+ "--image_drop_rate",
450
+ type=float,
451
+ default=0,
452
+ help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
453
+ )
454
+ parser.add_argument(
455
+ "--cond_drop_rate",
456
+ type=float,
457
+ default=0,
458
+ help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
459
+ )
460
+ parser.add_argument(
461
+ "--resolution",
462
+ type=int,
463
+ default=1024,
464
+ help=(
465
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
466
+ " resolution"
467
+ ),
468
+ )
469
+ parser.add_argument(
470
+ "--interpolation_type",
471
+ type=str,
472
+ default="bilinear",
473
+ help=(
474
+ "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`,"
475
+ " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
476
+ ),
477
+ )
478
+ parser.add_argument(
479
+ "--center_crop",
480
+ default=False,
481
+ action="store_true",
482
+ help=(
483
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
484
+ " cropped. The images will be resized to the resolution first before cropping."
485
+ ),
486
+ )
487
+ parser.add_argument(
488
+ "--random_flip",
489
+ action="store_true",
490
+ help="whether to randomly flip images horizontally",
491
+ )
492
+ parser.add_argument(
493
+ "--encode_batch_size",
494
+ type=int,
495
+ default=8,
496
+ help="Batch size to use for VAE encoding of the images for efficient processing.",
497
+ )
498
+ # ----Dataloader----
499
+ parser.add_argument(
500
+ "--dataloader_num_workers",
501
+ type=int,
502
+ default=0,
503
+ help=(
504
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
505
+ ),
506
+ )
507
+ # ----Batch Size and Training Steps----
508
+ parser.add_argument(
509
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
510
+ )
511
+ parser.add_argument("--num_train_epochs", type=int, default=100)
512
+ parser.add_argument(
513
+ "--max_train_steps",
514
+ type=int,
515
+ default=None,
516
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
517
+ )
518
+ parser.add_argument(
519
+ "--max_train_samples",
520
+ type=int,
521
+ default=None,
522
+ help=(
523
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
524
+ "value if set."
525
+ ),
526
+ )
527
+ # ----Learning Rate----
528
+ parser.add_argument(
529
+ "--learning_rate",
530
+ type=float,
531
+ default=1e-6,
532
+ help="Initial learning rate (after the potential warmup period) to use.",
533
+ )
534
+ parser.add_argument(
535
+ "--scale_lr",
536
+ action="store_true",
537
+ default=False,
538
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
539
+ )
540
+ parser.add_argument(
541
+ "--lr_scheduler",
542
+ type=str,
543
+ default="constant",
544
+ help=(
545
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
546
+ ' "constant", "constant_with_warmup"]'
547
+ ),
548
+ )
549
+ parser.add_argument(
550
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
551
+ )
552
+ parser.add_argument(
553
+ "--lr_num_cycles",
554
+ type=int,
555
+ default=1,
556
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
557
+ )
558
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
559
+ parser.add_argument(
560
+ "--gradient_accumulation_steps",
561
+ type=int,
562
+ default=1,
563
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
564
+ )
565
+ # ----Optimizer (Adam)----
566
+ parser.add_argument(
567
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
568
+ )
569
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
570
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
571
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
572
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
573
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
574
+ # ----Diffusion Training Arguments----
575
+ # ----Latent Consistency Distillation (LCD) Specific Arguments----
576
+ parser.add_argument(
577
+ "--w_min",
578
+ type=float,
579
+ default=3.0,
580
+ required=False,
581
+ help=(
582
+ "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
583
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
584
+ " compared to the original paper."
585
+ ),
586
+ )
587
+ parser.add_argument(
588
+ "--w_max",
589
+ type=float,
590
+ default=15.0,
591
+ required=False,
592
+ help=(
593
+ "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG"
594
+ " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as"
595
+ " compared to the original paper."
596
+ ),
597
+ )
598
+ parser.add_argument(
599
+ "--num_train_timesteps",
600
+ type=int,
601
+ default=1000,
602
+ help="The total number of diffusion timesteps in the training noise schedule.",
603
+ )
604
+ parser.add_argument(
605
+ "--num_ddim_timesteps",
606
+ type=int,
607
+ default=50,
608
+ help="The number of timesteps to use for DDIM sampling.",
609
+ )
610
+ parser.add_argument(
611
+ "--losses_config_path",
612
+ type=str,
613
+ default='config_files/losses.yaml',
614
+ required=True,
615
+ help=("A yaml file containing losses to use and their weights."),
616
+ )
617
+ parser.add_argument(
618
+ "--loss_type",
619
+ type=str,
620
+ default="l2",
621
+ choices=["l2", "huber"],
622
+ help="The type of loss to use for the LCD loss.",
623
+ )
624
+ parser.add_argument(
625
+ "--huber_c",
626
+ type=float,
627
+ default=0.001,
628
+ help="The huber loss parameter. Only used if `--loss_type=huber`.",
629
+ )
630
+ parser.add_argument(
631
+ "--lora_rank",
632
+ type=int,
633
+ default=64,
634
+ help="The rank of the LoRA projection matrix.",
635
+ )
636
+ parser.add_argument(
637
+ "--lora_alpha",
638
+ type=int,
639
+ default=64,
640
+ help=(
641
+ "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
642
+ " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
643
+ ),
644
+ )
645
+ parser.add_argument(
646
+ "--lora_dropout",
647
+ type=float,
648
+ default=0.0,
649
+ help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
650
+ )
651
+ parser.add_argument(
652
+ "--lora_target_modules",
653
+ type=str,
654
+ default=None,
655
+ help=(
656
+ "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
657
+ " be used. By default, LoRA will be applied to all conv and linear layers."
658
+ ),
659
+ )
660
+ parser.add_argument(
661
+ "--vae_encode_batch_size",
662
+ type=int,
663
+ default=8,
664
+ required=False,
665
+ help=(
666
+ "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE."
667
+ " Encoding or decoding the whole batch at once may run into OOM issues."
668
+ ),
669
+ )
670
+ parser.add_argument(
671
+ "--timestep_scaling_factor",
672
+ type=float,
673
+ default=10.0,
674
+ help=(
675
+ "The multiplicative timestep scaling factor used when calculating the boundary scalings for LCM. The"
676
+ " higher the scaling is, the lower the approximation error, but the default value of 10.0 should typically"
677
+ " suffice."
678
+ ),
679
+ )
680
+ # ----Mixed Precision----
681
+ parser.add_argument(
682
+ "--mixed_precision",
683
+ type=str,
684
+ default=None,
685
+ choices=["no", "fp16", "bf16"],
686
+ help=(
687
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
688
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
689
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
690
+ ),
691
+ )
692
+ parser.add_argument(
693
+ "--allow_tf32",
694
+ action="store_true",
695
+ help=(
696
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
697
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
698
+ ),
699
+ )
700
+ # ----Training Optimizations----
701
+ parser.add_argument(
702
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
703
+ )
704
+ parser.add_argument(
705
+ "--gradient_checkpointing",
706
+ action="store_true",
707
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
708
+ )
709
+ # ----Distributed Training----
710
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
711
+ # ----------Validation Arguments----------
712
+ parser.add_argument(
713
+ "--validation_steps",
714
+ type=int,
715
+ default=3000,
716
+ help="Run validation every X steps.",
717
+ )
718
+ parser.add_argument(
719
+ "--validation_image",
720
+ type=str,
721
+ default=None,
722
+ nargs="+",
723
+ help=(
724
+ "A set of paths to the low-quality conditioning images to be evaluated every `--validation_steps`"
725
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
726
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
727
+ " `--validation_image` that will be used with all `--validation_prompt`s."
728
+ ),
729
+ )
730
+ parser.add_argument(
731
+ "--validation_prompt",
732
+ type=str,
733
+ default=None,
734
+ nargs="+",
735
+ help=(
736
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
737
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
738
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
739
+ ),
740
+ )
741
+ parser.add_argument(
742
+ "--sanity_check",
743
+ action="store_true",
744
+ help=(
745
+ "sanity check"
746
+ ),
747
+ )
748
+ # ----------Huggingface Hub Arguments-----------
749
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
750
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
751
+ parser.add_argument(
752
+ "--hub_model_id",
753
+ type=str,
754
+ default=None,
755
+ help="The name of the repository to keep in sync with the local `output_dir`.",
756
+ )
757
+ # ----------Accelerate Arguments----------
758
+ parser.add_argument(
759
+ "--tracker_project_name",
760
+ type=str,
761
+ default="train",
762
+ help=(
763
+ "The `project_name` argument passed to Accelerator.init_trackers for"
764
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
765
+ ),
766
+ )
767
+
768
+ args = parser.parse_args()
769
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
770
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
771
+ args.local_rank = env_local_rank
772
+
773
+ return args
774
+
775
+
776
+ def main(args):
777
+ if args.report_to == "wandb" and args.hub_token is not None:
778
+ raise ValueError(
779
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
780
+ " Please use `huggingface-cli login` to authenticate with the Hub."
781
+ )
782
+
783
+ logging_dir = Path(args.output_dir, args.logging_dir)
784
+
785
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
786
+
787
+ accelerator = Accelerator(
788
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
789
+ mixed_precision=args.mixed_precision,
790
+ log_with=args.report_to,
791
+ project_config=accelerator_project_config,
792
+ )
793
+
794
+ # Make one log on every process with the configuration for debugging.
795
+ logging.basicConfig(
796
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
797
+ datefmt="%m/%d/%Y %H:%M:%S",
798
+ level=logging.INFO,
799
+ )
800
+ logger.info(accelerator.state, main_process_only=False)
801
+ if accelerator.is_local_main_process:
802
+ transformers.utils.logging.set_verbosity_warning()
803
+ diffusers.utils.logging.set_verbosity_info()
804
+ else:
805
+ transformers.utils.logging.set_verbosity_error()
806
+ diffusers.utils.logging.set_verbosity_error()
807
+
808
+ # If passed along, set the training seed now.
809
+ if args.seed is not None:
810
+ set_seed(args.seed)
811
+
812
+ # Handle the repository creation.
813
+ if accelerator.is_main_process:
814
+ if args.output_dir is not None:
815
+ os.makedirs(args.output_dir, exist_ok=True)
816
+
817
+ # 1. Create the noise scheduler and the desired noise schedule.
818
+ noise_scheduler = DDPMScheduler.from_pretrained(
819
+ args.pretrained_model_name_or_path, subfolder="scheduler", revision=args.teacher_revision
820
+ )
821
+ noise_scheduler.config.num_train_timesteps = args.num_train_timesteps
822
+ lcm_scheduler = LCMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
823
+
824
+ # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
825
+ alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
826
+ sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
827
+ # Initialize the DDIM ODE solver for distillation.
828
+ solver = DDIMSolver(
829
+ noise_scheduler.alphas_cumprod.numpy(),
830
+ timesteps=noise_scheduler.config.num_train_timesteps,
831
+ ddim_timesteps=args.num_ddim_timesteps,
832
+ )
833
+
834
+ # 2. Load tokenizers from SDXL checkpoint.
835
+ tokenizer_one = AutoTokenizer.from_pretrained(
836
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False
837
+ )
838
+ tokenizer_two = AutoTokenizer.from_pretrained(
839
+ args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False
840
+ )
841
+
842
+ # 3. Load text encoders from SDXL checkpoint.
843
+ # import correct text encoder classes
844
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
845
+ args.pretrained_model_name_or_path, args.teacher_revision
846
+ )
847
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
848
+ args.pretrained_model_name_or_path, args.teacher_revision, subfolder="text_encoder_2"
849
+ )
850
+
851
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
852
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.teacher_revision
853
+ )
854
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
855
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.teacher_revision
856
+ )
857
+
858
+ if args.use_clip_encoder:
859
+ image_processor = CLIPImageProcessor()
860
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
861
+ else:
862
+ image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
863
+ image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
864
+
865
+ # 4. Load VAE from SDXL checkpoint (or more stable VAE)
866
+ vae_path = (
867
+ args.pretrained_model_name_or_path
868
+ if args.pretrained_vae_model_name_or_path is None
869
+ else args.pretrained_vae_model_name_or_path
870
+ )
871
+ vae = AutoencoderKL.from_pretrained(
872
+ vae_path,
873
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
874
+ revision=args.teacher_revision,
875
+ )
876
+
877
+ # 7. Create online student U-Net.
878
+ unet = UNet2DConditionModel.from_pretrained(
879
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.teacher_revision
880
+ )
881
+
882
+ # Resampler for project model in IP-Adapter
883
+ image_proj_model = Resampler(
884
+ dim=1280,
885
+ depth=4,
886
+ dim_head=64,
887
+ heads=20,
888
+ num_queries=args.adapter_tokens,
889
+ embedding_dim=image_encoder.config.hidden_size,
890
+ output_dim=unet.config.cross_attention_dim,
891
+ ff_mult=4
892
+ )
893
+
894
+ # Load the same adapter in both unet.
895
+ init_adapter_in_unet(
896
+ unet,
897
+ image_proj_model,
898
+ os.path.join(args.pretrained_adapter_model_path, 'adapter_ckpt.pt'),
899
+ adapter_tokens=args.adapter_tokens,
900
+ )
901
+
902
+ # Check that all trainable models are in full precision
903
+ low_precision_error_string = (
904
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
905
+ " doing mixed precision training, copy of the weights should still be float32."
906
+ )
907
+
908
+ def unwrap_model(model):
909
+ model = accelerator.unwrap_model(model)
910
+ model = model._orig_mod if is_compiled_module(model) else model
911
+ return model
912
+
913
+ if unwrap_model(unet).dtype != torch.float32:
914
+ raise ValueError(
915
+ f"UNet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}"
916
+ )
917
+
918
+ if args.pretrained_lcm_lora_path is not None:
919
+ lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(args.pretrained_lcm_lora_path)
920
+ unet_state_dict = {
921
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
922
+ }
923
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
924
+ lora_state_dict = dict()
925
+ for k, v in unet_state_dict.items():
926
+ if "ip" in k:
927
+ k = k.replace("attn2", "attn2.processor")
928
+ lora_state_dict[k] = v
929
+ else:
930
+ lora_state_dict[k] = v
931
+ if alpha_dict:
932
+ args.lora_alpha = next(iter(alpha_dict.values()))
933
+ else:
934
+ args.lora_alpha = 1
935
+ # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer.
936
+ if args.lora_target_modules is not None:
937
+ lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")]
938
+ else:
939
+ lora_target_modules = [
940
+ "to_q",
941
+ "to_kv",
942
+ "0.to_out",
943
+ "attn1.to_k",
944
+ "attn1.to_v",
945
+ "to_k_ip",
946
+ "to_v_ip",
947
+ "ln_k_ip.linear",
948
+ "ln_v_ip.linear",
949
+ "to_out.0",
950
+ "proj_in",
951
+ "proj_out",
952
+ "ff.net.0.proj",
953
+ "ff.net.2",
954
+ "conv1",
955
+ "conv2",
956
+ "conv_shortcut",
957
+ "downsamplers.0.conv",
958
+ "upsamplers.0.conv",
959
+ "time_emb_proj",
960
+ ]
961
+ lora_config = LoraConfig(
962
+ r=args.lora_rank,
963
+ target_modules=lora_target_modules,
964
+ lora_alpha=args.lora_alpha,
965
+ lora_dropout=args.lora_dropout,
966
+ )
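+ # Note: with PEFT, the LoRA update delta_W = B @ A is scaled by lora_alpha / lora_rank, so
+ # when lora_alpha == lora_rank (e.g. the 64/64 defaults here) no extra scaling is applied,
+ # matching the --lora_alpha help text above.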
967
+
968
+ # Legacy
969
+ # for k, v in lcm_pipe.unet.state_dict().items():
970
+ # if "lora" in k or "base_layer" in k:
971
+ # lcm_dict[k.replace("default_0", "default")] = v
972
+
973
+ unet.add_adapter(lora_config)
974
+ if args.pretrained_lcm_lora_path is not None:
975
+ incompatible_keys = set_peft_model_state_dict(unet, lora_state_dict, adapter_name="default")
976
+ if incompatible_keys is not None:
977
+ # check only for unexpected keys
978
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
979
+ if unexpected_keys:
980
+ logger.warning(
981
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
982
+ f" {unexpected_keys}. "
983
+ )
984
+
985
+ # 6. Freeze unet, vae, text_encoders.
986
+ vae.requires_grad_(False)
987
+ text_encoder_one.requires_grad_(False)
988
+ text_encoder_two.requires_grad_(False)
989
+ image_encoder.requires_grad_(False)
990
+ unet.requires_grad_(False)
991
+
992
+ # 10. Handle saving and loading of checkpoints
993
+ # `accelerate` 0.16.0 will have better support for customized saving
994
+ if args.save_only_adapter:
995
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
996
+ def save_model_hook(models, weights, output_dir):
997
+ if accelerator.is_main_process:
998
+ for model in models:
999
+ if isinstance(model, type(unwrap_model(unet))): # save adapter only
1000
+ unet_ = unwrap_model(model)
1001
+ # also save the checkpoints in native `diffusers` format so that it can be easily
1002
+ # be independently loaded via `load_lora_weights()`.
1003
+ state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
1004
+ StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict, safe_serialization=False)
1005
+
1006
+ weights.pop()
1007
+
1008
+ def load_model_hook(models, input_dir):
1009
+
1010
+ while len(models) > 0:
1011
+ # pop models so that they are not loaded again
1012
+ model = models.pop()
1013
+
1014
+ if isinstance(model, type(unwrap_model(unet))):
1015
+ unet_ = unwrap_model(model)
1016
+ lora_state_dict, _ = StableDiffusionXLPipeline.lora_state_dict(input_dir)
1017
+ unet_state_dict = {
1018
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
1019
+ }
1020
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
1021
+ lora_state_dict = dict()
1022
+ for k, v in unet_state_dict.items():
1023
+ if "ip" in k:
1024
+ k = k.replace("attn2", "attn2.processor")
1025
+ lora_state_dict[k] = v
1026
+ else:
1027
+ lora_state_dict[k] = v
1028
+ incompatible_keys = set_peft_model_state_dict(unet_, lora_state_dict, adapter_name="default")
1029
+ if incompatible_keys is not None:
1030
+ # check only for unexpected keys
1031
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
1032
+ if unexpected_keys:
1033
+ logger.warning(
1034
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1035
+ f" {unexpected_keys}. "
1036
+ )
1037
+
1038
+ accelerator.register_save_state_pre_hook(save_model_hook)
1039
+ accelerator.register_load_state_pre_hook(load_model_hook)
1040
+
1041
+ # 11. Enable optimizations
1042
+ if args.enable_xformers_memory_efficient_attention:
1043
+ if is_xformers_available():
1044
+ import xformers
1045
+
1046
+ xformers_version = version.parse(xformers.__version__)
1047
+ if xformers_version == version.parse("0.0.16"):
1048
+ logger.warning(
1049
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1050
+ )
1051
+ unet.enable_xformers_memory_efficient_attention()
1052
+ else:
1053
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
1054
+
1055
+ # Enable TF32 for faster training on Ampere GPUs,
1056
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1057
+ if args.allow_tf32:
1058
+ torch.backends.cuda.matmul.allow_tf32 = True
1059
+
1060
+ if args.gradient_checkpointing:
1061
+ unet.enable_gradient_checkpointing()
1062
+ vae.enable_gradient_checkpointing()
1063
+
1064
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
1065
+ if args.use_8bit_adam:
1066
+ try:
1067
+ import bitsandbytes as bnb
1068
+ except ImportError:
1069
+ raise ImportError(
1070
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1071
+ )
1072
+
1073
+ optimizer_class = bnb.optim.AdamW8bit
1074
+ else:
1075
+ optimizer_class = torch.optim.AdamW
1076
+
1077
+ # 12. Optimizer creation
1078
+ lora_params, non_lora_params = seperate_lora_params_from_unet(unet)
1079
+ params_to_optimize = lora_params
1080
+ optimizer = optimizer_class(
1081
+ params_to_optimize,
1082
+ lr=args.learning_rate,
1083
+ betas=(args.adam_beta1, args.adam_beta2),
1084
+ weight_decay=args.adam_weight_decay,
1085
+ eps=args.adam_epsilon,
1086
+ )
1087
+
1088
+ # 13. Dataset creation and data processing
1089
+ # In distributed training, the load_dataset function guarantees that only one local process can concurrently
1090
+ # download the dataset.
1091
+ datasets = []
1092
+ datasets_name = []
1093
+ datasets_weights = []
1094
+ deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
1095
+ if args.data_config_path is not None:
1096
+ data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
1097
+ for single_dataset in data_config.datasets:
1098
+ datasets_weights.append(single_dataset.dataset_weight)
1099
+ datasets_name.append(single_dataset.dataset_folder)
1100
+ dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
1101
+ image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
1102
+ image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
1103
+ datasets.append(image_dataset)
1104
+ # TODO: Validation dataset
1105
+ if data_config.val_dataset is not None:
1106
+ val_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
1107
+ logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
1108
+
1109
+ # Mix training datasets.
1110
+ sampler_train = None
1111
+ if len(datasets) == 1:
1112
+ train_dataset = datasets[0]
1113
+ else:
1114
+ # Weighted each dataset
1115
+ train_dataset = torch.utils.data.ConcatDataset(datasets)
1116
+ dataset_weights = []
1117
+ for single_dataset, single_weight in zip(datasets, datasets_weights):
1118
+ dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
1119
+ sampler_train = torch.utils.data.WeightedRandomSampler(
1120
+ weights=dataset_weights,
1121
+ num_samples=len(dataset_weights)
1122
+ )
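+ # Illustrative example of the weighting above: mixing a 1,000-image dataset and a
+ # 3,000-image dataset, both with dataset_weight 1.0, yields per-sample weights of 4.0 and
+ # 4/3 respectively, so each dataset contributes equal total sampling mass regardless of its size.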
1123
+
1124
+ # DataLoaders creation:
1125
+ train_dataloader = torch.utils.data.DataLoader(
1126
+ train_dataset,
1127
+ sampler=sampler_train,
1128
+ shuffle=True if sampler_train is None else False,
1129
+ collate_fn=collate_fn,
1130
+ batch_size=args.train_batch_size,
1131
+ num_workers=args.dataloader_num_workers,
1132
+ )
1133
+
1134
+ # 14. Embeddings for the UNet.
1135
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1136
+ def compute_embeddings(prompt_batch, original_sizes, crop_coords, text_encoders, tokenizers, is_train=True):
1137
+ def compute_time_ids(original_size, crops_coords_top_left):
1138
+ target_size = (args.resolution, args.resolution)
1139
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1140
+ add_time_ids = torch.tensor([add_time_ids])
1141
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1142
+ return add_time_ids
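+ # SDXL micro-conditioning: add_time_ids concatenates original_size, crops_coords_top_left
+ # and target_size into a 6-dim vector per sample.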
1143
+
1144
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(prompt_batch, text_encoders, tokenizers, is_train)
1145
+ add_text_embeds = pooled_prompt_embeds
1146
+
1147
+ add_time_ids = torch.cat([compute_time_ids(s, c) for s, c in zip(original_sizes, crop_coords)])
1148
+
1149
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1150
+ add_text_embeds = add_text_embeds.to(accelerator.device)
1151
+ unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1152
+
1153
+ return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs}
1154
+
1155
+ text_encoders = [text_encoder_one, text_encoder_two]
1156
+ tokenizers = [tokenizer_one, tokenizer_two]
1157
+
1158
+ compute_embeddings_fn = functools.partial(compute_embeddings, text_encoders=text_encoders, tokenizers=tokenizers)
1159
+
1160
+ # Move pixels into latents.
1161
+ @torch.no_grad()
1162
+ def convert_to_latent(pixels):
1163
+ model_input = vae.encode(pixels).latent_dist.sample()
1164
+ model_input = model_input * vae.config.scaling_factor
1165
+ if args.pretrained_vae_model_name_or_path is None:
1166
+ model_input = model_input.to(weight_dtype)
1167
+ return model_input
1168
+
1169
+ # 15. LR Scheduler creation
1170
+ # Scheduler and math around the number of training steps.
1171
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
1172
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
1173
+ if args.max_train_steps is None:
1174
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
1175
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
1176
+ num_training_steps_for_scheduler = (
1177
+ args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
1178
+ )
1179
+ else:
1180
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
1181
+
1182
+ if args.scale_lr:
1183
+ args.learning_rate = (
1184
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1185
+ )
1186
+
1187
+ # Make sure the trainable params are in float32.
1188
+ if args.mixed_precision == "fp16":
1189
+ # only upcast trainable parameters (LoRA) into fp32
1190
+ cast_training_params(unet, dtype=torch.float32)
1191
+
1192
+ lr_scheduler = get_scheduler(
1193
+ args.lr_scheduler,
1194
+ optimizer=optimizer,
1195
+ num_warmup_steps=num_warmup_steps_for_scheduler,
1196
+ num_training_steps=num_training_steps_for_scheduler,
1197
+ )
1198
+
1199
+ # 16. Prepare for training
1200
+ # Prepare everything with our `accelerator`.
1201
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1202
+ unet, optimizer, train_dataloader, lr_scheduler
1203
+ )
1204
+
1205
+ # 8. Handle mixed precision and device placement
1206
+ # For mixed precision training we cast all non-trainable weigths to half-precision
1207
+ # as these weights are only used for inference, keeping weights in full precision is not required.
1208
+ weight_dtype = torch.float32
1209
+ if accelerator.mixed_precision == "fp16":
1210
+ weight_dtype = torch.float16
1211
+ elif accelerator.mixed_precision == "bf16":
1212
+ weight_dtype = torch.bfloat16
1213
+
1214
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
1215
+ # The VAE is in float32 to avoid NaN losses.
1216
+ if args.pretrained_vae_model_name_or_path is None:
1217
+ vae.to(accelerator.device, dtype=torch.float32)
1218
+ else:
1219
+ vae.to(accelerator.device, dtype=weight_dtype)
1220
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
1221
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
1222
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
1223
+ for p in non_lora_params:
1224
+ p.data = p.data.to(dtype=weight_dtype)
1225
+ for p in lora_params:
1226
+ p.requires_grad_(True)
1227
+ unet.to(accelerator.device)
1228
+
1229
+ # Also move the alpha and sigma noise schedules to accelerator.device.
1230
+ alpha_schedule = alpha_schedule.to(accelerator.device)
1231
+ sigma_schedule = sigma_schedule.to(accelerator.device)
1232
+ solver = solver.to(accelerator.device)
1233
+
1234
+ # Instantiate Loss.
1235
+ losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
1236
+ lcm_losses = list()
1237
+ for loss_config in losses_configs.lcm_losses:
1238
+ logger.info(f"Loading lcm loss: {loss_config.name}")
1239
+ loss = namedtuple("loss", ["loss", "weight"])
1240
+ loss_class = eval(loss_config.name)
1241
+ lcm_losses.append(loss(loss_class(
1242
+ visualize_every_k=loss_config.visualize_every_k,
1243
+ dtype=weight_dtype,
1244
+ accelerator=accelerator,
1245
+ dino_model=image_encoder,
1246
+ dino_preprocess=image_processor,
1247
+ huber_c=args.huber_c,
1248
+ **loss_config.init_params), weight=loss_config.weight))
1249
+
1250
+ # Final check.
1251
+ for n, p in unet.named_parameters():
1252
+ if p.requires_grad:
1253
+ assert "lora" in n, n
1254
+ assert p.dtype == torch.float32, n
1255
+ else:
1256
+ assert "lora" not in n, f"{n}"
1257
+ assert p.dtype == weight_dtype, n
1258
+ if args.sanity_check:
1259
+ if args.resume_from_checkpoint:
1260
+ if args.resume_from_checkpoint != "latest":
1261
+ path = os.path.basename(args.resume_from_checkpoint)
1262
+ else:
1263
+ # Get the most recent checkpoint
1264
+ dirs = os.listdir(args.output_dir)
1265
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1266
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1267
+ path = dirs[-1] if len(dirs) > 0 else None
1268
+
1269
+ if path is None:
1270
+ accelerator.print(
1271
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1272
+ )
1273
+ args.resume_from_checkpoint = None
1274
+ initial_global_step = 0
1275
+ else:
1276
+ accelerator.print(f"Resuming from checkpoint {path}")
1277
+ accelerator.load_state(os.path.join(args.output_dir, path))
1278
+
1279
+ # Check input data
1280
+ batch = next(iter(train_dataloader))
1281
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1282
+ out_images = log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two,
1283
+ lcm_scheduler, image_encoder, image_processor,
1284
+ args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, is_final_validation=False, log_local=True)
1285
+ exit()
1286
+
1287
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1288
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1289
+ if args.max_train_steps is None:
1290
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1291
+ if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
1292
+ logger.warning(
1293
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
1294
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
1295
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
1296
+ )
1297
+ # Afterwards we recalculate our number of training epochs
1298
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1299
+
1300
+ # We need to initialize the trackers we use, and also store our configuration.
1301
+ # The trackers initializes automatically on the main process.
1302
+ if accelerator.is_main_process:
1303
+ tracker_config = dict(vars(args))
1304
+
1305
+ # tensorboard cannot handle list types for config
1306
+ tracker_config.pop("validation_prompt")
1307
+ tracker_config.pop("validation_image")
1308
+
1309
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1310
+
1311
+ # 17. Train!
1312
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1313
+
1314
+ logger.info("***** Running training *****")
1315
+ logger.info(f" Num examples = {len(train_dataset)}")
1316
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1317
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1318
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1319
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1320
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1321
+ global_step = 0
1322
+ first_epoch = 0
1323
+
1324
+ # Potentially load in the weights and states from a previous save
1325
+ if args.resume_from_checkpoint:
1326
+ if args.resume_from_checkpoint != "latest":
1327
+ path = os.path.basename(args.resume_from_checkpoint)
1328
+ else:
1329
+ # Get the most recent checkpoint
1330
+ dirs = os.listdir(args.output_dir)
1331
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1332
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1333
+ path = dirs[-1] if len(dirs) > 0 else None
1334
+
1335
+ if path is None:
1336
+ accelerator.print(
1337
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1338
+ )
1339
+ args.resume_from_checkpoint = None
1340
+ initial_global_step = 0
1341
+ else:
1342
+ accelerator.print(f"Resuming from checkpoint {path}")
1343
+ accelerator.load_state(os.path.join(args.output_dir, path))
1344
+ global_step = int(path.split("-")[1])
1345
+
1346
+ initial_global_step = global_step
1347
+ first_epoch = global_step // num_update_steps_per_epoch
1348
+ else:
1349
+ initial_global_step = 0
1350
+
1351
+ progress_bar = tqdm(
1352
+ range(0, args.max_train_steps),
1353
+ initial=initial_global_step,
1354
+ desc="Steps",
1355
+ # Only show the progress bar once on each machine.
1356
+ disable=not accelerator.is_local_main_process,
1357
+ )
1358
+
1359
+ unet.train()
1360
+ for epoch in range(first_epoch, args.num_train_epochs):
1361
+ for step, batch in enumerate(train_dataloader):
1362
+ with accelerator.accumulate(unet):
1363
+ total_loss = torch.tensor(0.0)
1364
+ bsz = batch["images"].shape[0]
1365
+
1366
+ # Drop conditions.
1367
+ rand_tensor = torch.rand(bsz)
1368
+ drop_image_idx = rand_tensor < args.image_drop_rate
1369
+ drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
1370
+ drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
1371
+ drop_image_idx = drop_image_idx | drop_both_idx
1372
+ drop_text_idx = drop_text_idx | drop_both_idx
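+ # The three drop buckets are disjoint slices of the same uniform draw; e.g. with
+ # image/text/cond drop rates of 0.05 each: r < 0.05 drops the image condition,
+ # 0.05 <= r < 0.10 drops the text prompt, and 0.10 <= r < 0.15 drops both.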
1373
+
1374
+ with torch.no_grad():
1375
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1376
+ lq_pt = image_processor(
1377
+ images=lq_img*0.5+0.5,
1378
+ do_rescale=False, return_tensors="pt"
1379
+ ).pixel_values
1380
+ image_embeds = prepare_training_image_embeds(
1381
+ image_encoder, image_processor,
1382
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
1383
+ device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
1384
+ idx_to_replace=drop_image_idx
1385
+ )
1386
+ uncond_image_embeds = prepare_training_image_embeds(
1387
+ image_encoder, image_processor,
1388
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
1389
+ device=accelerator.device, drop_rate=1.0, output_hidden_state=args.image_encoder_hidden_feature,
1390
+ idx_to_replace=torch.ones_like(drop_image_idx)
1391
+ )
1392
+ # 1. Load and process the image and text conditioning
1393
+ text, orig_size, crop_coords = (
1394
+ batch["text"],
1395
+ batch["original_sizes"],
1396
+ batch["crop_top_lefts"],
1397
+ )
1398
+
1399
+ encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
1400
+ uncond_encoded_text = compute_embeddings_fn([""]*len(text), orig_size, crop_coords)
1401
+
1402
+ # encode pixel values with batch size of at most args.vae_encode_batch_size
1403
+ gt_img = gt_img.to(dtype=vae.dtype)
1404
+ latents = []
1405
+ for i in range(0, gt_img.shape[0], args.vae_encode_batch_size):
1406
+ latents.append(vae.encode(gt_img[i : i + args.vae_encode_batch_size]).latent_dist.sample())
1407
+ latents = torch.cat(latents, dim=0)
1408
+ # latents = convert_to_latent(gt_img)
1409
+
1410
+ latents = latents * vae.config.scaling_factor
1411
+ if args.pretrained_vae_model_name_or_path is None:
1412
+ latents = latents.to(weight_dtype)
1413
+
1414
+ # 2. Sample a random timestep for each image t_n from the ODE solver timesteps without bias.
1415
+ # For the DDIM solver, the timestep schedule is [T - 1, T - k - 1, T - 2 * k - 1, ...]
1416
+ bsz = latents.shape[0]
1417
+ topk = noise_scheduler.config.num_train_timesteps // args.num_ddim_timesteps
1418
+ index = torch.randint(0, args.num_ddim_timesteps, (bsz,), device=latents.device).long()
1419
+ start_timesteps = solver.ddim_timesteps[index]
1420
+ timesteps = start_timesteps - topk
1421
+ timesteps = torch.where(timesteps < 0, torch.zeros_like(timesteps), timesteps)
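+ # start_timesteps is t_{n+k} on the DDIM solver grid; timesteps is the next solver point t_n, clamped at 0.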
1422
+
1423
+ # 3. Get boundary scalings for start_timesteps and (end) timesteps.
1424
+ c_skip_start, c_out_start = scalings_for_boundary_conditions(
1425
+ start_timesteps, timestep_scaling=args.timestep_scaling_factor
1426
+ )
1427
+ c_skip_start, c_out_start = [append_dims(x, latents.ndim) for x in [c_skip_start, c_out_start]]
1428
+ c_skip, c_out = scalings_for_boundary_conditions(
1429
+ timesteps, timestep_scaling=args.timestep_scaling_factor
1430
+ )
1431
+ c_skip, c_out = [append_dims(x, latents.ndim) for x in [c_skip, c_out]]
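+ # The boundary scalings parameterize the consistency function as f(z, t) = c_skip(t) * z + c_out(t) * x0_pred, so it reduces to the identity at the boundary timestep.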
1432
+
1433
+ # 4. Sample noise from the prior and add it to the latents according to the noise magnitude at each
1434
+ # timestep (this is the forward diffusion process) [z_{t_{n + k}} in Algorithm 1]
1435
+ noise = torch.randn_like(latents)
1436
+ noisy_model_input = noise_scheduler.add_noise(latents, noise, start_timesteps)
1437
+
1438
+ # 5. Sample a random guidance scale w from U[w_min, w_max]
1439
+ # Note that for LCM-LoRA distillation it is not necessary to use a guidance scale embedding
1440
+ w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
1441
+ w = w.reshape(bsz, 1, 1, 1)
1442
+ w = w.to(device=latents.device, dtype=latents.dtype)
1443
+
1444
+ # 6. Prepare prompt embeds and unet_added_conditions
1445
+ prompt_embeds = encoded_text.pop("prompt_embeds")
1446
+ encoded_text["image_embeds"] = image_embeds
1447
+ uncond_prompt_embeds = uncond_encoded_text.pop("prompt_embeds")
1448
+ uncond_encoded_text["image_embeds"] = image_embeds
1449
+
1450
+ # 7. Get online LCM prediction on z_{t_{n + k}} (noisy_model_input), w, c, t_{n + k} (start_timesteps)
1451
+ noise_pred = unet(
1452
+ noisy_model_input,
1453
+ start_timesteps,
1454
+ encoder_hidden_states=uncond_prompt_embeds,
1455
+ added_cond_kwargs=uncond_encoded_text,
1456
+ ).sample
1457
+ pred_x_0 = get_predicted_original_sample(
1458
+ noise_pred,
1459
+ start_timesteps,
1460
+ noisy_model_input,
1461
+ noise_scheduler.config.prediction_type,
1462
+ alpha_schedule,
1463
+ sigma_schedule,
1464
+ )
1465
+ model_pred = c_skip_start * noisy_model_input + c_out_start * pred_x_0
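+ # model_pred is the online (student) consistency prediction at t_{n+k}; it is matched against the target computed at t_n below.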
1466
+
1467
+ # 8. Compute the conditional and unconditional teacher model predictions to get CFG estimates of the
1468
+ # predicted noise eps_0 and predicted original sample x_0, then run the ODE solver using these
1469
+ # estimates to predict the data point in the augmented PF-ODE trajectory corresponding to the next ODE
1470
+ # solver timestep.
1471
+
1472
+ # With the adapters disabled, the `unet` is the regular teacher model.
1473
+ accelerator.unwrap_model(unet).disable_adapters()
1474
+ with torch.no_grad():
1475
+
1476
+ # 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
1477
+ teacher_added_cond = dict()
1478
+ for k,v in encoded_text.items():
1479
+ if isinstance(v, torch.Tensor):
1480
+ teacher_added_cond[k] = v.to(weight_dtype)
1481
+ else:
1482
+ teacher_image_embeds = []
1483
+ for img_emb in v:
1484
+ teacher_image_embeds.append(img_emb.to(weight_dtype))
1485
+ teacher_added_cond[k] = teacher_image_embeds
1486
+ cond_teacher_output = unet(
1487
+ noisy_model_input,
1488
+ start_timesteps,
1489
+ encoder_hidden_states=prompt_embeds,
1490
+ added_cond_kwargs=teacher_added_cond,
1491
+ ).sample
1492
+ cond_pred_x0 = get_predicted_original_sample(
1493
+ cond_teacher_output,
1494
+ start_timesteps,
1495
+ noisy_model_input,
1496
+ noise_scheduler.config.prediction_type,
1497
+ alpha_schedule,
1498
+ sigma_schedule,
1499
+ )
1500
+ cond_pred_noise = get_predicted_noise(
1501
+ cond_teacher_output,
1502
+ start_timesteps,
1503
+ noisy_model_input,
1504
+ noise_scheduler.config.prediction_type,
1505
+ alpha_schedule,
1506
+ sigma_schedule,
1507
+ )
1508
+
1509
+ # 2. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and unconditional embedding 0
1510
+ teacher_added_uncond = dict()
1511
+ uncond_encoded_text["image_embeds"] = uncond_image_embeds
1512
+ for k,v in uncond_encoded_text.items():
1513
+ if isinstance(v, torch.Tensor):
1514
+ teacher_added_uncond[k] = v.to(weight_dtype)
1515
+ else:
1516
+ teacher_uncond_image_embeds = []
1517
+ for img_emb in v:
1518
+ teacher_uncond_image_embeds.append(img_emb.to(weight_dtype))
1519
+ teacher_added_uncond[k] = teacher_uncond_image_embeds
1520
+ uncond_teacher_output = unet(
1521
+ noisy_model_input,
1522
+ start_timesteps,
1523
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
1524
+ added_cond_kwargs=teacher_added_uncond,
1525
+ ).sample
1526
+ uncond_pred_x0 = get_predicted_original_sample(
1527
+ uncond_teacher_output,
1528
+ start_timesteps,
1529
+ noisy_model_input,
1530
+ noise_scheduler.config.prediction_type,
1531
+ alpha_schedule,
1532
+ sigma_schedule,
1533
+ )
1534
+ uncond_pred_noise = get_predicted_noise(
1535
+ uncond_teacher_output,
1536
+ start_timesteps,
1537
+ noisy_model_input,
1538
+ noise_scheduler.config.prediction_type,
1539
+ alpha_schedule,
1540
+ sigma_schedule,
1541
+ )
1542
+
1543
+ # 3. Calculate the CFG estimate of x_0 (pred_x0) and eps_0 (pred_noise)
1544
+ # Note that this uses the LCM paper's CFG formulation rather than the Imagen CFG formulation
1545
+ pred_x0 = cond_pred_x0 + w * (cond_pred_x0 - uncond_pred_x0)
1546
+ pred_noise = cond_pred_noise + w * (cond_pred_noise - uncond_pred_noise)
1547
+ # 4. Run one step of the ODE solver to estimate the next point x_prev on the
1548
+ # augmented PF-ODE trajectory (solving backward in time)
1549
+ # Note that the DDIM step depends on both the predicted x_0 and source noise eps_0.
1550
+ x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(weight_dtype)
1551
+
1552
+ # re-enable unet adapters to turn the `unet` into a student unet.
1553
+ accelerator.unwrap_model(unet).enable_adapters()
1554
+
1555
+ # 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
1556
+ # Note that we do not use a separate target network for LCM-LoRA distillation.
1557
+ with torch.no_grad():
1558
+ uncond_encoded_text["image_embeds"] = image_embeds
1559
+ target_added_cond = dict()
1560
+ for k,v in uncond_encoded_text.items():
1561
+ if isinstance(v, torch.Tensor):
1562
+ target_added_cond[k] = v.to(weight_dtype)
1563
+ else:
1564
+ target_image_embeds = []
1565
+ for img_emb in v:
1566
+ target_image_embeds.append(img_emb.to(weight_dtype))
1567
+ target_added_cond[k] = target_image_embeds
1568
+ target_noise_pred = unet(
1569
+ x_prev,
1570
+ timesteps,
1571
+ encoder_hidden_states=uncond_prompt_embeds.to(weight_dtype),
1572
+ added_cond_kwargs=target_added_cond,
1573
+ ).sample
1574
+ pred_x_0 = get_predicted_original_sample(
1575
+ target_noise_pred,
1576
+ timesteps,
1577
+ x_prev,
1578
+ noise_scheduler.config.prediction_type,
1579
+ alpha_schedule,
1580
+ sigma_schedule,
1581
+ )
1582
+ target = c_skip * x_prev + c_out * pred_x_0
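+ # target is the consistency prediction at the earlier solver point t_n, computed without gradients from the one-step DDIM estimate x_prev; the loss below pulls model_pred toward it.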
1583
+
1584
+ # 10. Calculate loss
1585
+ lcm_loss_arguments = {
1586
+ "target": target.float(),
1587
+ "predict": model_pred.float(),
1588
+ }
1589
+ loss_dict = dict()
1590
+ # total_loss = total_loss + torch.mean(
1591
+ # torch.sqrt((model_pred.float() - target.float()) ** 2 + args.huber_c**2) - args.huber_c
1592
+ # )
1593
+ # loss_dict["L2Loss"] = total_loss.item()
1594
+ for loss_config in lcm_losses:
1595
+ if loss_config.loss.__class__.__name__=="DINOLoss":
1596
+ with torch.no_grad():
1597
+ pixel_target = []
1598
+ latent_target = target.to(dtype=vae.dtype)
1599
+ for i in range(0, latent_target.shape[0], args.vae_encode_batch_size):
1600
+ pixel_target.append(
1601
+ vae.decode(
1602
+ latent_target[i : i + args.vae_encode_batch_size] / vae.config.scaling_factor,
1603
+ return_dict=False
1604
+ )[0]
1605
+ )
1606
+ pixel_target = torch.cat(pixel_target, dim=0)
1607
+ pixel_pred = []
1608
+ latent_pred = model_pred.to(dtype=vae.dtype)
1609
+ for i in range(0, latent_pred.shape[0], args.vae_encode_batch_size):
1610
+ pixel_pred.append(
1611
+ vae.decode(
1612
+ latent_pred[i : i + args.vae_encode_batch_size] / vae.config.scaling_factor,
1613
+ return_dict=False
1614
+ )[0]
1615
+ )
1616
+ pixel_pred = torch.cat(pixel_pred, dim=0)
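+ # The DINO loss operates in pixel space, so both the target and predicted latents are decoded through the VAE above in chunks of vae_encode_batch_size.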
1617
+ dino_loss_arguments = {
1618
+ "target": pixel_target,
1619
+ "predict": pixel_pred,
1620
+ }
1621
+ non_weighted_loss = loss_config.loss(**dino_loss_arguments, accelerator=accelerator)
1622
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
1623
+ total_loss = total_loss + non_weighted_loss * loss_config.weight
1624
+ else:
1625
+ non_weighted_loss = loss_config.loss(**lcm_loss_arguments, accelerator=accelerator)
1626
+ total_loss = total_loss + non_weighted_loss * loss_config.weight
1627
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
1628
+
1629
+ # 11. Backpropagate on the online student model (`unet`) (only LoRA)
1630
+ accelerator.backward(total_loss)
1631
+ if accelerator.sync_gradients:
1632
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
1633
+ optimizer.step()
1634
+ lr_scheduler.step()
1635
+ optimizer.zero_grad(set_to_none=True)
1636
+
1637
+ # Checks if the accelerator has performed an optimization step behind the scenes
1638
+ if accelerator.sync_gradients:
1639
+ progress_bar.update(1)
1640
+ global_step += 1
1641
+
1642
+ if accelerator.is_main_process:
1643
+ if global_step % args.checkpointing_steps == 0:
1644
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1645
+ if args.checkpoints_total_limit is not None:
1646
+ checkpoints = os.listdir(args.output_dir)
1647
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1648
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1649
+
1650
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1651
+ if len(checkpoints) >= args.checkpoints_total_limit:
1652
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1653
+ removing_checkpoints = checkpoints[0:num_to_remove]
1654
+
1655
+ logger.info(
1656
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1657
+ )
1658
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1659
+
1660
+ for removing_checkpoint in removing_checkpoints:
1661
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1662
+ shutil.rmtree(removing_checkpoint)
1663
+
1664
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1665
+ accelerator.save_state(save_path)
1666
+ logger.info(f"Saved state to {save_path}")
1667
+
1668
+ if global_step % args.validation_steps == 0:
1669
+ out_images = log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two,
1670
+ lcm_scheduler, image_encoder, image_processor,
1671
+ args, accelerator, weight_dtype, global_step, lq_img, gt_img, is_final_validation=False, log_local=False)
1672
+
1673
+ logs = dict()
1674
+ # logs.update({"loss": loss.detach().item()})
1675
+ logs.update(loss_dict)
1676
+ logs.update({"lr": lr_scheduler.get_last_lr()[0]})
1677
+ progress_bar.set_postfix(**logs)
1678
+ accelerator.log(logs, step=global_step)
1679
+
1680
+ if global_step >= args.max_train_steps:
1681
+ break
1682
+
1683
+ # Create the pipeline using the trained modules and save it.
1684
+ accelerator.wait_for_everyone()
1685
+ if accelerator.is_main_process:
1686
+ unet = accelerator.unwrap_model(unet)
1687
+ unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
1688
+ StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)
1689
+
1690
+ if args.push_to_hub:
1691
+ upload_folder(
1692
+ repo_id=repo_id,
1693
+ folder_path=args.output_dir,
1694
+ commit_message="End of training",
1695
+ ignore_patterns=["step_*", "epoch_*"],
1696
+ )
1697
+
1698
+ # Final inference.
1699
+ if args.validation_steps is not None:
1700
+     log_validation(unwrap_model(unet), vae, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two,
1701
+                    lcm_scheduler, image_encoder=None, image_processor=None,
1702
+                    args=args, accelerator=accelerator, weight_dtype=weight_dtype, step=0, is_final_validation=False, log_local=True)
1703
+
1704
+ del unet
1705
+ torch.cuda.empty_cache()
1706
+
1707
+ accelerator.end_training()
1708
+
1709
+
1710
+ if __name__ == "__main__":
1711
+ args = parse_args()
1712
+ main(args)
train_previewer_lora.sh ADDED
@@ -0,0 +1,24 @@
1
+ # After DCP training, distill the Previewer with DCP in `train_previewer_lora.py`:
2
+ accelerate launch --num_processes <num_of_gpus> train_previewer_lora.py \
3
+ --output_dir <your/output/path> \
4
+ --train_data_dir <your/data/path> \
5
+ --logging_dir <your/logging/path> \
6
+ --pretrained_model_name_or_path <your/sdxl/path> \
7
+ --feature_extractor_path <your/dinov2/path> \
8
+ --pretrained_adapter_model_path <your/dcp/path> \
9
+ --losses_config_path config_files/losses.yaml \
10
+ --data_config_path config_files/IR_dataset.yaml \
11
+ --save_only_adapter \
12
+ --gradient_checkpointing \
13
+ --num_train_timesteps 1000 \
14
+ --num_ddim_timesteps 50 \
15
+ --lora_alpha 1 \
16
+ --mixed_precision fp16 \
17
+ --train_batch_size 32 \
18
+ --vae_encode_batch_size 16 \
19
+ --gradient_accumulation_steps 1 \
20
+ --learning_rate 1e-4 \
21
+ --lr_warmup_steps 1000 \
22
+ --lr_scheduler cosine \
23
+ --lr_num_cycles 1 \
24
+ --resume_from_checkpoint latest
train_stage1_adapter.py ADDED
@@ -0,0 +1,1259 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import contextlib
18
+ import time
19
+ import gc
20
+ import logging
21
+ import math
22
+ import os
23
+ import random
24
+ import jsonlines
25
+ import functools
26
+ import shutil
27
+ import pyrallis
28
+ import itertools
29
+ from pathlib import Path
30
+ from collections import namedtuple, OrderedDict
31
+
32
+ import accelerate
33
+ import numpy as np
34
+ import torch
35
+ import torch.nn.functional as F
36
+ import torch.utils.checkpoint
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
41
+ from datasets import load_dataset
42
+ from packaging import version
43
+ from PIL import Image
44
+ from data.data_config import DataConfig
45
+ from utils.degradation_pipeline import RealESRGANDegradation
46
+ from losses.loss_config import LossesConfig
47
+ from losses.losses import *
48
+ from torchvision import transforms
49
+ from torchvision.transforms.functional import crop
50
+ from tqdm.auto import tqdm
51
+ from transformers import (
52
+ AutoTokenizer,
53
+ PretrainedConfig,
54
+ CLIPImageProcessor, CLIPVisionModelWithProjection,
55
+ AutoImageProcessor, AutoModel)
56
+
57
+ import diffusers
58
+ from diffusers import (
59
+ AutoencoderKL,
60
+ AutoencoderTiny,
61
+ DDPMScheduler,
62
+ StableDiffusionXLPipeline,
63
+ UNet2DConditionModel,
64
+ )
65
+ from diffusers.optimization import get_scheduler
66
+ from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
67
+ from diffusers.utils.import_utils import is_xformers_available
68
+ from diffusers.utils.torch_utils import is_compiled_module
69
+
70
+ from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
71
+ from utils.train_utils import (
72
+ seperate_ip_params_from_unet,
73
+ import_model_class_from_model_name_or_path,
74
+ tensor_to_pil,
75
+ get_train_dataset, prepare_train_dataset, collate_fn,
76
+ encode_prompt, importance_sampling_fn, extract_into_tensor
77
+ )
78
+ from module.ip_adapter.resampler import Resampler
79
+ from module.ip_adapter.attention_processor import init_attn_proc
80
+ from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
81
+
82
+
83
+ if is_wandb_available():
84
+ import wandb
85
+
86
+
87
+ logger = get_logger(__name__)
88
+
89
+
90
+ def log_validation(unet, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
91
+ scheduler, image_encoder, image_processor, deg_pipeline,
92
+ args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
93
+ logger.info("Running validation... ")
94
+
95
+ image_logs = []
96
+
97
+ lq = [Image.open(lq_example) for lq_example in args.validation_image]
98
+
99
+ pipe = StableDiffusionXLPipeline(
100
+ vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
101
+ unet, scheduler, image_encoder, image_processor,
102
+ ).to(accelerator.device)
103
+
104
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
105
+ image = pipe(
106
+ prompt=[""]*len(lq),
107
+ ip_adapter_image=[lq],
108
+ num_inference_steps=20,
109
+ generator=generator,
110
+ guidance_scale=5.0,
111
+ height=args.resolution,
112
+ width=args.resolution,
113
+ ).images
114
+
115
+ if log_local:
116
+ for i, img in enumerate(tensor_to_pil(lq_img)):
117
+ img.save(f"./lq_{i}.png")
118
+ for i, img in enumerate(tensor_to_pil(gt_img)):
119
+ img.save(f"./gt_{i}.png")
120
+ for i, img in enumerate(image):
121
+ img.save(f"./lq_IPA_{i}.png")
122
+ return
123
+
124
+ tracker_key = "test" if is_final_validation else "validation"
125
+ for tracker in accelerator.trackers:
126
+ if tracker.name == "tensorboard":
127
+ images = [np.asarray(pil_img) for pil_img in image]
128
+ images = np.stack(images, axis=0)
129
+ if lq_img is not None and gt_img is not None:
130
+ input_lq = lq_img.detach().cpu()
131
+ input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
132
+ input_gt = gt_img.detach().cpu()
133
+ input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
134
+ tracker.writer.add_images("lq", input_lq[0], step, dataformats="CHW")
135
+ tracker.writer.add_images("gt", input_gt[0], step, dataformats="CHW")
136
+ tracker.writer.add_images("rec", images, step, dataformats="NHWC")
137
+ elif tracker.name == "wandb":
138
+ raise NotImplementedError("Wandb logging not implemented for validation.")
139
+ formatted_images = []
140
+
141
+ for log in image_logs:
142
+ images = log["images"]
143
+ validation_prompt = log["validation_prompt"]
144
+ validation_image = log["validation_image"]
145
+
146
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
147
+
148
+ for image in images:
149
+ image = wandb.Image(image, caption=validation_prompt)
150
+ formatted_images.append(image)
151
+
152
+ tracker.log({tracker_key: formatted_images})
153
+ else:
154
+ logger.warning(f"image logging not implemented for {tracker.name}")
155
+
156
+ gc.collect()
157
+ torch.cuda.empty_cache()
158
+
159
+ return image_logs
160
+
161
+
162
+ def parse_args(input_args=None):
163
+ parser = argparse.ArgumentParser(description="InstantIR stage-1 training.")
164
+ parser.add_argument(
165
+ "--pretrained_model_name_or_path",
166
+ type=str,
167
+ default=None,
168
+ required=True,
169
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
170
+ )
171
+ parser.add_argument(
172
+ "--pretrained_vae_model_name_or_path",
173
+ type=str,
174
+ default=None,
175
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
176
+ )
177
+ parser.add_argument(
178
+ "--feature_extractor_path",
179
+ type=str,
180
+ default=None,
181
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
182
+ )
183
+ parser.add_argument(
184
+ "--pretrained_adapter_model_path",
185
+ type=str,
186
+ default=None,
187
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
188
+ )
189
+ parser.add_argument(
190
+ "--adapter_tokens",
191
+ type=int,
192
+ default=64,
193
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
194
+ )
195
+ parser.add_argument(
196
+ "--use_clip_encoder",
197
+ action="store_true",
198
+ help="Whether or not to use CLIP as the image encoder; if not set, DINO is used.",
199
+ )
200
+ parser.add_argument(
201
+ "--image_encoder_hidden_feature",
202
+ action="store_true",
203
+ help="Whether or not to use the penultimate hidden states as image embeddings.",
204
+ )
205
+ parser.add_argument(
206
+ "--losses_config_path",
207
+ type=str,
208
+ required=True,
209
+ default='config_files/losses.yaml',
210
+ help=("A yaml file containing losses to use and their weights."),
211
+ )
212
+ parser.add_argument(
213
+ "--data_config_path",
214
+ type=str,
215
+ default='config_files/IR_dataset.yaml',
216
+ help=("A yaml file describing the training datasets and their sampling weights."),
217
+ )
218
+ parser.add_argument(
219
+ "--variant",
220
+ type=str,
221
+ default=None,
222
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
223
+ )
224
+ parser.add_argument(
225
+ "--revision",
226
+ type=str,
227
+ default=None,
228
+ required=False,
229
+ help="Revision of pretrained model identifier from huggingface.co/models.",
230
+ )
231
+ parser.add_argument(
232
+ "--tokenizer_name",
233
+ type=str,
234
+ default=None,
235
+ help="Pretrained tokenizer name or path if not the same as model_name",
236
+ )
237
+ parser.add_argument(
238
+ "--output_dir",
239
+ type=str,
240
+ default="stage1_model",
241
+ help="The output directory where the model predictions and checkpoints will be written.",
242
+ )
243
+ parser.add_argument(
244
+ "--cache_dir",
245
+ type=str,
246
+ default=None,
247
+ help="The directory where the downloaded models and datasets will be stored.",
248
+ )
249
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
250
+ parser.add_argument(
251
+ "--resolution",
252
+ type=int,
253
+ default=512,
254
+ help=(
255
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
256
+ " resolution"
257
+ ),
258
+ )
259
+ parser.add_argument(
260
+ "--crops_coords_top_left_h",
261
+ type=int,
262
+ default=0,
263
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
264
+ )
265
+ parser.add_argument(
266
+ "--crops_coords_top_left_w",
267
+ type=int,
268
+ default=0,
269
+ help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
270
+ )
271
+ parser.add_argument(
272
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
273
+ )
274
+ parser.add_argument("--num_train_epochs", type=int, default=1)
275
+ parser.add_argument(
276
+ "--max_train_steps",
277
+ type=int,
278
+ default=None,
279
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
280
+ )
281
+ parser.add_argument(
282
+ "--checkpointing_steps",
283
+ type=int,
284
+ default=2000,
285
+ help=(
286
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
287
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
288
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
289
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step "
290
+ "instructions."
291
+ ),
292
+ )
293
+ parser.add_argument(
294
+ "--checkpoints_total_limit",
295
+ type=int,
296
+ default=5,
297
+ help=("Max number of checkpoints to store."),
298
+ )
299
+ parser.add_argument(
300
+ "--resume_from_checkpoint",
301
+ type=str,
302
+ default=None,
303
+ help=(
304
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
305
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
306
+ ),
307
+ )
308
+ parser.add_argument(
309
+ "--gradient_accumulation_steps",
310
+ type=int,
311
+ default=1,
312
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
313
+ )
314
+ parser.add_argument(
315
+ "--gradient_checkpointing",
316
+ action="store_true",
317
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
318
+ )
319
+ parser.add_argument(
320
+ "--save_only_adapter",
321
+ action="store_true",
322
+ help="Only save extra adapter to save space.",
323
+ )
324
+ parser.add_argument(
325
+ "--importance_sampling",
326
+ action="store_true",
327
+ help="Whether or not to use importance sampling.",
328
+ )
329
+ parser.add_argument(
330
+ "--learning_rate",
331
+ type=float,
332
+ default=1e-4,
333
+ help="Initial learning rate (after the potential warmup period) to use.",
334
+ )
335
+ parser.add_argument(
336
+ "--scale_lr",
337
+ action="store_true",
338
+ default=False,
339
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
340
+ )
341
+ parser.add_argument(
342
+ "--lr_scheduler",
343
+ type=str,
344
+ default="constant",
345
+ help=(
346
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
347
+ ' "constant", "constant_with_warmup"]'
348
+ ),
349
+ )
350
+ parser.add_argument(
351
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
352
+ )
353
+ parser.add_argument(
354
+ "--lr_num_cycles",
355
+ type=int,
356
+ default=1,
357
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
358
+ )
359
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
360
+ parser.add_argument(
361
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
362
+ )
363
+ parser.add_argument(
364
+ "--dataloader_num_workers",
365
+ type=int,
366
+ default=0,
367
+ help=(
368
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
369
+ ),
370
+ )
371
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
372
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
373
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
374
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
375
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
376
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
377
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
378
+ parser.add_argument(
379
+ "--hub_model_id",
380
+ type=str,
381
+ default=None,
382
+ help="The name of the repository to keep in sync with the local `output_dir`.",
383
+ )
384
+ parser.add_argument(
385
+ "--logging_dir",
386
+ type=str,
387
+ default="logs",
388
+ help=(
389
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
390
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
391
+ ),
392
+ )
393
+ parser.add_argument(
394
+ "--allow_tf32",
395
+ action="store_true",
396
+ help=(
397
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
398
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
399
+ ),
400
+ )
401
+ parser.add_argument(
402
+ "--report_to",
403
+ type=str,
404
+ default="tensorboard",
405
+ help=(
406
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
407
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
408
+ ),
409
+ )
410
+ parser.add_argument(
411
+ "--mixed_precision",
412
+ type=str,
413
+ default=None,
414
+ choices=["no", "fp16", "bf16"],
415
+ help=(
416
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
417
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
418
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
419
+ ),
420
+ )
421
+ parser.add_argument(
422
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
423
+ )
424
+ parser.add_argument(
425
+ "--set_grads_to_none",
426
+ action="store_true",
427
+ help=(
428
+ "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
429
+ " behaviors, so disable this argument if it causes any problems. More info:"
430
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
431
+ ),
432
+ )
433
+ parser.add_argument(
434
+ "--dataset_name",
435
+ type=str,
436
+ default=None,
437
+ help=(
438
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
439
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
440
+ " or to a folder containing files that 🤗 Datasets can understand."
441
+ ),
442
+ )
443
+ parser.add_argument(
444
+ "--dataset_config_name",
445
+ type=str,
446
+ default=None,
447
+ help="The config of the Dataset, leave as None if there's only one config.",
448
+ )
449
+ parser.add_argument(
450
+ "--train_data_dir",
451
+ type=str,
452
+ default=None,
453
+ help=(
454
+ "A folder containing the training data. Folder contents must follow the structure described in"
455
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
456
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
457
+ ),
458
+ )
459
+ parser.add_argument(
460
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
461
+ )
462
+ parser.add_argument(
463
+ "--conditioning_image_column",
464
+ type=str,
465
+ default="conditioning_image",
466
+ help="The column of the dataset containing the controlnet conditioning image.",
467
+ )
468
+ parser.add_argument(
469
+ "--caption_column",
470
+ type=str,
471
+ default="text",
472
+ help="The column of the dataset containing a caption or a list of captions.",
473
+ )
474
+ parser.add_argument(
475
+ "--max_train_samples",
476
+ type=int,
477
+ default=None,
478
+ help=(
479
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
480
+ "value if set."
481
+ ),
482
+ )
483
+ parser.add_argument(
484
+ "--text_drop_rate",
485
+ type=float,
486
+ default=0.05,
487
+ help="Proportion of text prompts to be replaced with empty strings. Defaults to 0.05.",
488
+ )
489
+ parser.add_argument(
490
+ "--image_drop_rate",
491
+ type=float,
492
+ default=0.05,
493
+ help="Proportion of IP-Adapter image inputs to be dropped. Defaults to 0.05.",
494
+ )
495
+ parser.add_argument(
496
+ "--cond_drop_rate",
497
+ type=float,
498
+ default=0.05,
499
+ help="Proportion of samples for which all conditions are dropped. Defaults to 0.05.",
500
+ )
501
+ parser.add_argument(
502
+ "--sanity_check",
503
+ action="store_true",
504
+ help=(
505
+ "Run a quick sanity check: load one training batch, log a local validation sample, and exit."
506
+ ),
507
+ )
508
+ parser.add_argument(
509
+ "--validation_prompt",
510
+ type=str,
511
+ default=None,
512
+ nargs="+",
513
+ help=(
514
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
515
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
516
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
517
+ ),
518
+ )
519
+ parser.add_argument(
520
+ "--validation_image",
521
+ type=str,
522
+ default=None,
523
+ nargs="+",
524
+ help=(
525
+ "A set of paths to conditioning images to be evaluated every `--validation_steps`"
526
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
527
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
528
+ " `--validation_image` that will be used with all `--validation_prompt`s."
529
+ ),
530
+ )
531
+ parser.add_argument(
532
+ "--num_validation_images",
533
+ type=int,
534
+ default=4,
535
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
536
+ )
537
+ parser.add_argument(
538
+ "--validation_steps",
539
+ type=int,
540
+ default=3000,
541
+ help=(
542
+ "Run validation every X steps. Validation consists of running the prompt"
543
+ " `args.validation_prompt` multiple times: `args.num_validation_images`"
544
+ " and logging the images."
545
+ ),
546
+ )
547
+ parser.add_argument(
548
+ "--tracker_project_name",
549
+ type=str,
550
+ default="instantir_stage1",
551
+ help=(
552
+ "The `project_name` argument passed to Accelerator.init_trackers for"
553
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
554
+ ),
555
+ )
556
+
557
+ if input_args is not None:
558
+ args = parser.parse_args(input_args)
559
+ else:
560
+ args = parser.parse_args()
561
+
562
+ # if args.dataset_name is None and args.train_data_dir is None and args.data_config_path is None:
563
+ # raise ValueError("Specify either `--dataset_name` or `--train_data_dir`")
564
+
565
+ if args.dataset_name is not None and args.train_data_dir is not None:
566
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
567
+
568
+ if args.text_drop_rate < 0 or args.text_drop_rate > 1:
569
+ raise ValueError("`--text_drop_rate` must be in the range [0, 1].")
570
+
571
+ if args.validation_prompt is not None and args.validation_image is None:
572
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
573
+
574
+ if args.validation_prompt is None and args.validation_image is not None:
575
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
576
+
577
+ if (
578
+ args.validation_image is not None
579
+ and args.validation_prompt is not None
580
+ and len(args.validation_image) != 1
581
+ and len(args.validation_prompt) != 1
582
+ and len(args.validation_image) != len(args.validation_prompt)
583
+ ):
584
+ raise ValueError(
585
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
586
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
587
+ )
588
+
589
+ if args.resolution % 8 != 0:
590
+ raise ValueError(
591
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
592
+ )
593
+
594
+ return args
595
+
596
+
597
+ def main(args):
598
+ if args.report_to == "wandb" and args.hub_token is not None:
599
+ raise ValueError(
600
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
601
+ " Please use `huggingface-cli login` to authenticate with the Hub."
602
+ )
603
+
604
+ logging_dir = Path(args.output_dir, args.logging_dir)
605
+
606
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
607
+ # due to pytorch#99272, MPS does not yet support bfloat16.
608
+ raise ValueError(
609
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
610
+ )
611
+
612
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
613
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
614
+ accelerator = Accelerator(
615
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
616
+ mixed_precision=args.mixed_precision,
617
+ log_with=args.report_to,
618
+ project_config=accelerator_project_config,
619
+ # kwargs_handlers=[kwargs],
620
+ )
621
+
622
+ # Make one log on every process with the configuration for debugging.
623
+ logging.basicConfig(
624
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
625
+ datefmt="%m/%d/%Y %H:%M:%S",
626
+ level=logging.INFO,
627
+ )
628
+ logger.info(accelerator.state, main_process_only=False)
629
+ if accelerator.is_local_main_process:
630
+ transformers.utils.logging.set_verbosity_warning()
631
+ diffusers.utils.logging.set_verbosity_info()
632
+ else:
633
+ transformers.utils.logging.set_verbosity_error()
634
+ diffusers.utils.logging.set_verbosity_error()
635
+
636
+ # If passed along, set the training seed now.
637
+ if args.seed is not None:
638
+ set_seed(args.seed)
639
+
640
+ # Handle the repository creation.
641
+ if accelerator.is_main_process:
642
+ if args.output_dir is not None:
643
+ os.makedirs(args.output_dir, exist_ok=True)
644
+
645
+ # Load scheduler and models
646
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
647
+ # Importance sampling.
648
+ list_of_candidates = np.arange(noise_scheduler.config.num_train_timesteps, dtype='float64')
649
+ prob_dist = importance_sampling_fn(list_of_candidates, noise_scheduler.config.num_train_timesteps, 0.5)
650
+ importance_ratio = prob_dist / prob_dist.sum() * noise_scheduler.config.num_train_timesteps
651
+ importance_ratio = torch.from_numpy(importance_ratio.copy()).float()
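+ # The ratios are normalized so they average to 1 over the num_train_timesteps timesteps.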
652
+
653
+ # Load the tokenizers
654
+ tokenizer = AutoTokenizer.from_pretrained(
655
+ args.pretrained_model_name_or_path,
656
+ subfolder="tokenizer",
657
+ revision=args.revision,
658
+ use_fast=False,
659
+ )
660
+ tokenizer_2 = AutoTokenizer.from_pretrained(
661
+ args.pretrained_model_name_or_path,
662
+ subfolder="tokenizer_2",
663
+ revision=args.revision,
664
+ use_fast=False,
665
+ )
666
+
667
+ # Text encoder and image encoder.
668
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
669
+ args.pretrained_model_name_or_path, args.revision
670
+ )
671
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
672
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
673
+ )
674
+ text_encoder = text_encoder_cls_one.from_pretrained(
675
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
676
+ )
677
+ text_encoder_2 = text_encoder_cls_two.from_pretrained(
678
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
679
+ )
680
+ if args.use_clip_encoder:
681
+ image_processor = CLIPImageProcessor()
682
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
683
+ else:
684
+ image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
685
+ image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
686
+
687
+ # VAE.
688
+ vae_path = (
689
+ args.pretrained_model_name_or_path
690
+ if args.pretrained_vae_model_name_or_path is None
691
+ else args.pretrained_vae_model_name_or_path
692
+ )
693
+ vae = AutoencoderKL.from_pretrained(
694
+ vae_path,
695
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
696
+ revision=args.revision,
697
+ variant=args.variant,
698
+ )
699
+
700
+ # UNet.
701
+ unet = UNet2DConditionModel.from_pretrained(
702
+ args.pretrained_model_name_or_path,
703
+ subfolder="unet",
704
+ revision=args.revision,
705
+ variant=args.variant
706
+ )
707
+
708
+ pipe = StableDiffusionXLPipeline.from_pretrained(
709
+ args.pretrained_model_name_or_path,
710
+ unet=unet,
711
+ text_encoder=text_encoder,
712
+ text_encoder_2=text_encoder_2,
713
+ vae=vae,
714
+ tokenizer=tokenizer,
715
+ tokenizer_2=tokenizer_2,
716
+ variant=args.variant
717
+ )
718
+
719
+ # Resampler as the image projection model of the IP-Adapter.
720
+ image_proj_model = Resampler(
721
+ dim=1280,
722
+ depth=4,
723
+ dim_head=64,
724
+ heads=20,
725
+ num_queries=args.adapter_tokens,
726
+ embedding_dim=image_encoder.config.hidden_size,
727
+ output_dim=unet.config.cross_attention_dim,
728
+ ff_mult=4
729
+ )
730
+
731
+ init_adapter_in_unet(
732
+ unet,
733
+ image_proj_model,
734
+ os.path.join(args.pretrained_adapter_model_path, 'adapter_ckpt.pt'),
735
+ adapter_tokens=args.adapter_tokens,
736
+ )
737
+
738
+ # Initialize training state.
739
+ vae.requires_grad_(False)
740
+ text_encoder.requires_grad_(False)
741
+ text_encoder_2.requires_grad_(False)
742
+ unet.requires_grad_(False)
743
+ image_encoder.requires_grad_(False)
744
+
745
+ def unwrap_model(model):
746
+ model = accelerator.unwrap_model(model)
747
+ model = model._orig_mod if is_compiled_module(model) else model
748
+ return model
749
+
750
+ # `accelerate` 0.16.0 will have better support for customized saving
751
+ if args.save_only_adapter:
752
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
753
+ def save_model_hook(models, weights, output_dir):
754
+ if accelerator.is_main_process:
755
+ for model in models:
756
+ if isinstance(model, type(unwrap_model(unet))): # save adapter only
757
+ adapter_state_dict = OrderedDict()
758
+ adapter_state_dict["image_proj"] = model.encoder_hid_proj.image_projection_layers[0].state_dict()
759
+ adapter_state_dict["ip_adapter"] = torch.nn.ModuleList(model.attn_processors.values()).state_dict()
760
+ torch.save(adapter_state_dict, os.path.join(output_dir, "adapter_ckpt.pt"))
761
+
762
+ weights.pop()
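+ # Pop the UNet weights from the list so accelerate does not also serialize the full model; only the adapter checkpoint saved above is kept.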
763
+
764
+ def load_model_hook(models, input_dir):
765
+
766
+ while len(models) > 0:
767
+ # pop models so that they are not loaded again
768
+ model = models.pop()
769
+
770
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
771
+ adapter_state_dict = torch.load(os.path.join(input_dir, "adapter_ckpt.pt"), map_location="cpu")
772
+ if list(adapter_state_dict.keys()) != ["image_proj", "ip_adapter"]:
773
+ from module.ip_adapter.utils import revise_state_dict
774
+ adapter_state_dict = revise_state_dict(adapter_state_dict)
775
+ model.encoder_hid_proj.image_projection_layers[0].load_state_dict(adapter_state_dict["image_proj"], strict=True)
776
+ missing, unexpected = torch.nn.ModuleList(model.attn_processors.values()).load_state_dict(adapter_state_dict["ip_adapter"], strict=False)
777
+ if len(unexpected) > 0:
778
+ raise ValueError(f"Unexpected keys: {unexpected}")
779
+ if len(missing) > 0:
780
+ for mk in missing:
781
+ if "ln" not in mk:
782
+ raise ValueError(f"Missing keys: {missing}")
783
+
784
+ accelerator.register_save_state_pre_hook(save_model_hook)
785
+ accelerator.register_load_state_pre_hook(load_model_hook)
786
+
787
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
788
+ # as these models are only used for inference, keeping weights in full precision is not required.
789
+ weight_dtype = torch.float32
790
+ if accelerator.mixed_precision == "fp16":
791
+ weight_dtype = torch.float16
792
+ elif accelerator.mixed_precision == "bf16":
793
+ weight_dtype = torch.bfloat16
794
+
795
+ if args.enable_xformers_memory_efficient_attention:
796
+ if is_xformers_available():
797
+ import xformers
798
+
799
+ xformers_version = version.parse(xformers.__version__)
800
+ if xformers_version == version.parse("0.0.16"):
801
+ logger.warning(
802
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
803
+ )
804
+ unet.enable_xformers_memory_efficient_attention()
805
+ else:
806
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
807
+
808
+ if args.gradient_checkpointing:
809
+ unet.enable_gradient_checkpointing()
810
+ vae.enable_gradient_checkpointing()
811
+
812
+ # Enable TF32 for faster training on Ampere GPUs,
813
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
814
+ if args.allow_tf32:
815
+ torch.backends.cuda.matmul.allow_tf32 = True
816
+
817
+ if args.scale_lr:
818
+ args.learning_rate = (
819
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
820
+ )
821
+
822
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
823
+ if args.use_8bit_adam:
824
+ try:
825
+ import bitsandbytes as bnb
826
+ except ImportError:
827
+ raise ImportError(
828
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
829
+ )
830
+
831
+ optimizer_class = bnb.optim.AdamW8bit
832
+ else:
833
+ optimizer_class = torch.optim.AdamW
834
+
835
+ # Optimizer creation.
836
+ ip_params, non_ip_params = seperate_ip_params_from_unet(unet)
837
+ params_to_optimize = ip_params
838
+ optimizer = optimizer_class(
839
+ params_to_optimize,
840
+ lr=args.learning_rate,
841
+ betas=(args.adam_beta1, args.adam_beta2),
842
+ weight_decay=args.adam_weight_decay,
843
+ eps=args.adam_epsilon,
844
+ )
845
+
846
+ # Instantiate Loss.
847
+ losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
848
+ diffusion_losses = list()
849
+ for loss_config in losses_configs.diffusion_losses:
850
+ logger.info(f"Loading diffusion loss: {loss_config.name}")
851
+ loss = namedtuple("loss", ["loss", "weight"])
852
+ loss_class = eval(loss_config.name)
853
+ diffusion_losses.append(loss(loss_class(visualize_every_k=loss_config.visualize_every_k,
854
+ dtype=weight_dtype,
855
+ accelerator=accelerator,
856
+ **loss_config.init_params), weight=loss_config.weight))
857
+
858
+ # SDXL additional condition that will be added to time embedding.
859
+ def compute_time_ids(original_size, crops_coords_top_left):
860
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
861
+ target_size = (args.resolution, args.resolution)
862
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
863
+ add_time_ids = torch.tensor([add_time_ids])
864
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
865
+ return add_time_ids
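+ # add_time_ids concatenates (original_size, crop_top_left, target_size), the micro-conditioning SDXL expects alongside the time embedding.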
866
+
867
+ # Text prompt embeddings.
868
+ @torch.no_grad()
869
+ def compute_embeddings(batch, text_encoders, tokenizers, drop_idx=None, is_train=True):
870
+ prompt_batch = batch[args.caption_column]
871
+ if drop_idx is not None:
872
+ for i in range(len(prompt_batch)):
873
+ prompt_batch[i] = "" if drop_idx[i] else prompt_batch[i]
874
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
875
+ prompt_batch, text_encoders, tokenizers, is_train
876
+ )
877
+
878
+ add_time_ids = torch.cat(
879
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
880
+ )
881
+
882
+ prompt_embeds = prompt_embeds.to(accelerator.device)
883
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
884
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
885
+ sdxl_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
886
+
887
+ return prompt_embeds, sdxl_added_cond_kwargs
888
+
889
+ # Move pixels into latents.
890
+ @torch.no_grad()
891
+ def convert_to_latent(pixels):
892
+ model_input = vae.encode(pixels).latent_dist.sample()
893
+ model_input = model_input * vae.config.scaling_factor
894
+ if args.pretrained_vae_model_name_or_path is None:
895
+ model_input = model_input.to(weight_dtype)
896
+ return model_input
897
+
898
+ # Datasets and other data modules.
899
+ deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
900
+ compute_embeddings_fn = functools.partial(
901
+ compute_embeddings,
902
+ text_encoders=[text_encoder, text_encoder_2],
903
+ tokenizers=[tokenizer, tokenizer_2],
904
+ is_train=True,
905
+ )
906
+
907
+ datasets = []
908
+ datasets_name = []
909
+ datasets_weights = []
910
+ if args.data_config_path is not None:
911
+ data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
912
+ for single_dataset in data_config.datasets:
913
+ datasets_weights.append(single_dataset.dataset_weight)
914
+ datasets_name.append(single_dataset.dataset_folder)
915
+ dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
916
+ image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
917
+ image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
918
+ datasets.append(image_dataset)
919
+ # TODO: Validation dataset
920
+ if data_config.val_dataset is not None:
921
+ val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator)
922
+ logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
923
+
924
+ # Mix training datasets.
925
+ sampler_train = None
926
+ if len(datasets) == 1:
927
+ train_dataset = datasets[0]
928
+ else:
929
+ # Weight each dataset by its configured sampling weight
930
+ train_dataset = torch.utils.data.ConcatDataset(datasets)
931
+ dataset_weights = []
932
+ for single_dataset, single_weight in zip(datasets, datasets_weights):
933
+ dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
934
+ sampler_train = torch.utils.data.WeightedRandomSampler(
935
+ weights=dataset_weights,
936
+ num_samples=len(dataset_weights)
937
+ )
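+ # Each sample is weighted by (total samples / dataset size) * dataset_weight, so the chance of drawing from a dataset follows its configured weight regardless of its size.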
938
+
939
+ train_dataloader = torch.utils.data.DataLoader(
940
+ train_dataset,
941
+ batch_size=args.train_batch_size,
942
+ sampler=sampler_train,
943
+ shuffle=True if sampler_train is None else False,
944
+ collate_fn=collate_fn,
945
+ num_workers=args.dataloader_num_workers
946
+ )
947
+
948
+ # We need to initialize the trackers we use, and also store our configuration.
949
+ # The trackers initialize automatically on the main process.
950
+ if accelerator.is_main_process:
951
+ tracker_config = dict(vars(args))
952
+
953
+ # tensorboard cannot handle list types for config
954
+ tracker_config.pop("validation_prompt")
955
+ tracker_config.pop("validation_image")
956
+
957
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
958
+
959
+ # Scheduler and math around the number of training steps.
960
+ overrode_max_train_steps = False
961
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
962
+ if args.max_train_steps is None:
963
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
964
+ overrode_max_train_steps = True
965
+
966
+ lr_scheduler = get_scheduler(
967
+ args.lr_scheduler,
968
+ optimizer=optimizer,
969
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
970
+ num_training_steps=args.max_train_steps,
971
+ num_cycles=args.lr_num_cycles,
972
+ power=args.lr_power,
973
+ )
974
+
975
+ # Prepare everything with our `accelerator`.
976
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
977
+ unet, optimizer, train_dataloader, lr_scheduler
978
+ )
979
+
980
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
981
+ if args.pretrained_vae_model_name_or_path is None:
982
+ # The VAE is fp32 to avoid NaN losses.
983
+ vae.to(accelerator.device, dtype=torch.float32)
984
+ else:
985
+ vae.to(accelerator.device, dtype=weight_dtype)
986
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
987
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
988
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
989
+ importance_ratio = importance_ratio.to(accelerator.device)
990
+ for non_ip_param in non_ip_params:
991
+ non_ip_param.data = non_ip_param.data.to(dtype=weight_dtype)
992
+ for ip_param in ip_params:
993
+ ip_param.requires_grad_(True)
994
+ unet.to(accelerator.device)
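+ # Only the IP-Adapter parameters remain trainable (kept in fp32); the frozen UNet weights stay in weight_dtype.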
995
+
996
+ # Final check.
997
+ for n, p in unet.named_parameters():
998
+ if p.requires_grad: assert p.dtype == torch.float32, n
999
+ else: assert p.dtype == weight_dtype, n
1000
+ if args.sanity_check:
1001
+ if args.resume_from_checkpoint:
1002
+ if args.resume_from_checkpoint != "latest":
1003
+ path = os.path.basename(args.resume_from_checkpoint)
1004
+ else:
1005
+ # Get the most recent checkpoint
1006
+ dirs = os.listdir(args.output_dir)
1007
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1008
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1009
+ path = dirs[-1] if len(dirs) > 0 else None
1010
+
1011
+ if path is None:
1012
+ accelerator.print(
1013
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1014
+ )
1015
+ args.resume_from_checkpoint = None
1016
+ initial_global_step = 0
1017
+ else:
1018
+ accelerator.print(f"Resuming from checkpoint {path}")
1019
+ accelerator.load_state(os.path.join(args.output_dir, path))
1020
+
1021
+ # Check input data
1022
+ batch = next(iter(train_dataloader))
1023
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1024
+ images_log = log_validation(
1025
+ unwrap_model(unet), vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1026
+ noise_scheduler, image_encoder, image_processor, deg_pipeline,
1027
+ args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, is_final_validation=False, log_local=True
1028
+ )
1029
+ exit()
1030
+
1031
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1032
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1033
+ if overrode_max_train_steps:
1034
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1035
+ # Afterwards we recalculate our number of training epochs
1036
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1037
+
1038
+ # Train!
1039
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1040
+
1041
+ logger.info("***** Running training *****")
1042
+ logger.info(f" Num examples = {len(train_dataset)}")
1043
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1044
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1045
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1046
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1047
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1048
+ logger.info(f" Optimization steps per epoch = {num_update_steps_per_epoch}")
1049
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1050
+ global_step = 0
1051
+ first_epoch = 0
1052
+
1053
+ # Potentially load in the weights and states from a previous save
1054
+ if args.resume_from_checkpoint:
1055
+ if args.resume_from_checkpoint != "latest":
1056
+ path = os.path.basename(args.resume_from_checkpoint)
1057
+ else:
1058
+ # Get the most recent checkpoint
1059
+ dirs = os.listdir(args.output_dir)
1060
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1061
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1062
+ path = dirs[-1] if len(dirs) > 0 else None
1063
+
1064
+ if path is None:
1065
+ accelerator.print(
1066
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1067
+ )
1068
+ args.resume_from_checkpoint = None
1069
+ initial_global_step = 0
1070
+ else:
1071
+ accelerator.print(f"Resuming from checkpoint {path}")
1072
+ accelerator.load_state(os.path.join(args.output_dir, path))
1073
+ global_step = int(path.split("-")[1])
1074
+
1075
+ initial_global_step = global_step
1076
+ first_epoch = global_step // num_update_steps_per_epoch
1077
+ else:
1078
+ initial_global_step = 0
1079
+
1080
+ progress_bar = tqdm(
1081
+ range(0, args.max_train_steps),
1082
+ initial=initial_global_step,
1083
+ desc="Steps",
1084
+ # Only show the progress bar once on each machine.
1085
+ disable=not accelerator.is_local_main_process,
1086
+ )
1087
+
1088
+ trainable_models = [unet]
1089
+
1090
+ if args.gradient_checkpointing:
1091
+ checkpoint_models = []
1092
+ else:
1093
+ checkpoint_models = []
1094
+
1095
+ image_logs = None
1096
+ tic = time.time()
1097
+ for epoch in range(first_epoch, args.num_train_epochs):
1098
+ for step, batch in enumerate(train_dataloader):
1099
+ toc = time.time()
1100
+ io_time = toc - tic
1101
+ tic = toc
1102
+ for model in trainable_models + checkpoint_models:
1103
+ model.train()
1104
+ with accelerator.accumulate(*trainable_models):
1105
+ loss = torch.tensor(0.0)
1106
+
1107
+ # Drop conditions.
1108
+ rand_tensor = torch.rand(batch["images"].shape[0])
1109
+ drop_image_idx = rand_tensor < args.image_drop_rate
1110
+ drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
1111
+ drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
1112
+ drop_image_idx = drop_image_idx | drop_both_idx
1113
+ drop_text_idx = drop_text_idx | drop_both_idx
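+ # The single uniform draw is split into disjoint intervals: image-only drop, text-only drop, and a joint drop that removes both conditions.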
1114
+
1115
+ # Get LQ embeddings
1116
+ with torch.no_grad():
1117
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1118
+ lq_pt = image_processor(
1119
+ images=lq_img*0.5+0.5,
1120
+ do_rescale=False, return_tensors="pt"
1121
+ ).pixel_values
1122
+ image_embeds = prepare_training_image_embeds(
1123
+ image_encoder, image_processor,
1124
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
1125
+ device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
1126
+ idx_to_replace=drop_image_idx
1127
+ )
1128
+
1129
+ # Process text inputs.
1130
+ prompt_embeds_input, added_conditions = compute_embeddings_fn(batch, drop_idx=drop_text_idx)
1131
+ added_conditions["image_embeds"] = image_embeds
1132
+
1133
+ # Move inputs to latent space.
1134
+ gt_img = gt_img.to(dtype=vae.dtype)
1135
+ model_input = convert_to_latent(gt_img)
1136
+ if args.pretrained_vae_model_name_or_path is None:
1137
+ model_input = model_input.to(weight_dtype)
1138
+
1139
+ # Sample noise that we'll add to the latents.
1140
+ noise = torch.randn_like(model_input)
1141
+ bsz = model_input.shape[0]
1142
+
1143
+ # Sample a random timestep for each image.
1144
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)
1145
+
1146
+ # Add noise to the model input according to the noise magnitude at each timestep
1147
+ # (this is the forward diffusion process)
1148
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1149
+ loss_weights = extract_into_tensor(importance_ratio, timesteps, noise.shape) if args.importance_sampling else None
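+ # With --importance_sampling, a per-timestep weight is looked up from the precomputed ratio and broadcast to the noise shape to rescale the diffusion loss below.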
1150
+
1151
+ toc = time.time()
1152
+ prepare_time = toc - tic
1153
+ tic = time.time()
1154
+
1155
+ model_pred = unet(
1156
+ noisy_model_input, timesteps,
1157
+ encoder_hidden_states=prompt_embeds_input,
1158
+ added_cond_kwargs=added_conditions,
1159
+ return_dict=False
1160
+ )[0]
1161
+
1162
+ diffusion_loss_arguments = {
1163
+ "target": noise,
1164
+ "predict": model_pred,
1165
+ "prompt_embeddings_input": prompt_embeds_input,
1166
+ "timesteps": timesteps,
1167
+ "weights": loss_weights,
1168
+ }
1169
+
1170
+ loss_dict = dict()
1171
+ for loss_config in diffusion_losses:
1172
+ non_weighted_loss = loss_config.loss(**diffusion_loss_arguments, accelerator=accelerator)
1173
+ loss = loss + non_weighted_loss * loss_config.weight
1174
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
1175
+
1176
+ accelerator.backward(loss)
1177
+ if accelerator.sync_gradients:
1178
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
1179
+ optimizer.step()
1180
+ lr_scheduler.step()
1181
+ optimizer.zero_grad()
1182
+
1183
+ toc = time.time()
1184
+ forward_time = toc - tic
1185
+ tic = toc
1186
+
1187
+ # Checks if the accelerator has performed an optimization step behind the scenes
1188
+ if accelerator.sync_gradients:
1189
+ progress_bar.update(1)
1190
+ global_step += 1
1191
+
1192
+ if accelerator.is_main_process:
1193
+ if global_step % args.checkpointing_steps == 0:
1194
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1195
+ if args.checkpoints_total_limit is not None:
1196
+ checkpoints = os.listdir(args.output_dir)
1197
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1198
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1199
+
1200
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1201
+ if len(checkpoints) >= args.checkpoints_total_limit:
1202
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1203
+ removing_checkpoints = checkpoints[0:num_to_remove]
1204
+
1205
+ logger.info(
1206
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1207
+ )
1208
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1209
+
1210
+ for removing_checkpoint in removing_checkpoints:
1211
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1212
+ shutil.rmtree(removing_checkpoint)
1213
+
1214
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1215
+ accelerator.save_state(save_path)
1216
+ logger.info(f"Saved state to {save_path}")
1217
+
1218
+ if global_step % args.validation_steps == 0:
1219
+ image_logs = log_validation(unwrap_model(unet), vae,
1220
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1221
+ noise_scheduler, image_encoder, image_processor, deg_pipeline,
1222
+ args, accelerator, weight_dtype, global_step, lq_img, gt_img, is_final_validation=False)
1223
+
1224
+ logs = {}
1225
+ logs.update(loss_dict)
1226
+ logs.update({
1227
+ "lr": lr_scheduler.get_last_lr()[0],
1228
+ "io_time": io_time,
1229
+ "prepare_time": prepare_time,
1230
+ "forward_time": forward_time
1231
+ })
1232
+ progress_bar.set_postfix(**logs)
1233
+ accelerator.log(logs, step=global_step)
1234
+ tic = time.time()
1235
+
1236
+ if global_step >= args.max_train_steps:
1237
+ break
1238
+
1239
+ # Create the pipeline using using the trained modules and save it.
1240
+ accelerator.wait_for_everyone()
1241
+ if accelerator.is_main_process:
1242
+ accelerator.save_state(os.path.join(args.output_dir, "last"), safe_serialization=False)
1243
+ # Run a final round of validation.
1244
+ # The in-memory modules are reused here; nothing is reloaded from `args.output_dir`.
1245
+ image_logs = None
1246
+ if args.validation_image is not None:
1247
+ image_logs = log_validation(
1248
+ unwrap_model(unet), vae,
1249
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1250
+ noise_scheduler, image_encoder, image_processor, deg_pipeline,
1251
+ args, accelerator, weight_dtype, global_step,
1252
+ )
1253
+
1254
+ accelerator.end_training()
1255
+
1256
+
1257
+ if __name__ == "__main__":
1258
+ args = parse_args()
1259
+ main(args)
train_stage1_adapter.sh ADDED
@@ -0,0 +1,17 @@
1
+ # Stage 1: train the low-quality (LQ) adapter
2
+ accelerate launch --num_processes <num_of_gpus> train_stage1_adapter.py \
3
+ --output_dir <your/output/path> \
4
+ --train_data_dir <your/data/path> \
5
+ --logging_dir <your/logging/path> \
6
+ --pretrained_model_name_or_path <your/sdxl/path> \
7
+ --feature_extractor_path <your/dinov2/path> \
8
+ --save_only_adapter \
9
+ --gradient_checkpointing \
10
+ --mixed_precision fp16 \
11
+ --train_batch_size 96 \
12
+ --gradient_accumulation_steps 1 \
13
+ --learning_rate 1e-4 \
14
+ --lr_warmup_steps 1000 \
15
+ --lr_scheduler cosine \
16
+ --lr_num_cycles 1 \
17
+ --resume_from_checkpoint latest
train_stage2_aggregator.py ADDED
@@ -0,0 +1,1698 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ import os
17
+ import argparse
18
+ import time
19
+ import gc
20
+ import logging
21
+ import math
22
+ import copy
23
+ import random
24
+ import yaml
25
+ import functools
26
+ import shutil
27
+ import pyrallis
28
+ from pathlib import Path
29
+ from collections import namedtuple, OrderedDict
30
+
31
+ import accelerate
32
+ import numpy as np
33
+ import torch
34
+ from safetensors import safe_open
35
+ import torch.nn.functional as F
36
+ import torch.utils.checkpoint
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
41
+ from datasets import load_dataset
42
+ from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
43
+ from huggingface_hub import create_repo, upload_folder
44
+ from packaging import version
45
+ from PIL import Image
46
+ from data.data_config import DataConfig
47
+ from basicsr.utils.degradation_pipeline import RealESRGANDegradation
48
+ from losses.loss_config import LossesConfig
49
+ from losses.losses import *
50
+ from torchvision import transforms
51
+ from torchvision.transforms.functional import crop
52
+ from tqdm.auto import tqdm
53
+ from transformers import (
54
+ AutoTokenizer,
55
+ PretrainedConfig,
56
+ CLIPImageProcessor, CLIPVisionModelWithProjection,
57
+ AutoImageProcessor, AutoModel
58
+ )
59
+
60
+ import diffusers
61
+ from diffusers import (
62
+ AutoencoderKL,
63
+ DDPMScheduler,
64
+ StableDiffusionXLPipeline,
65
+ UNet2DConditionModel,
66
+ )
67
+ from diffusers.optimization import get_scheduler
68
+ from diffusers.utils import (
69
+ check_min_version,
70
+ convert_unet_state_dict_to_peft,
71
+ is_wandb_available,
72
+ )
73
+ from diffusers.utils.import_utils import is_xformers_available
74
+ from diffusers.utils.torch_utils import is_compiled_module
75
+
76
+ from module.aggregator import Aggregator
77
+ from schedulers.lcm_single_step_scheduler import LCMSingleStepScheduler
78
+ from module.ip_adapter.ip_adapter import MultiIPAdapterImageProjection
79
+ from module.ip_adapter.resampler import Resampler
80
+ from module.ip_adapter.utils import init_adapter_in_unet, prepare_training_image_embeds
81
+ from module.ip_adapter.attention_processor import init_attn_proc
82
+ from utils.train_utils import (
83
+ seperate_ip_params_from_unet,
84
+ import_model_class_from_model_name_or_path,
85
+ tensor_to_pil,
86
+ get_train_dataset, prepare_train_dataset, collate_fn,
87
+ encode_prompt, importance_sampling_fn, extract_into_tensor
88
+ )
89
+ from pipelines.sdxl_instantir import InstantIRPipeline
90
+
91
+
92
+ if is_wandb_available():
93
+ import wandb
94
+
95
+
96
+ logger = get_logger(__name__)
97
+
98
+
99
+ def log_validation(unet, aggregator, vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
100
+ scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
101
+ args, accelerator, weight_dtype, step, lq_img=None, gt_img=None, is_final_validation=False, log_local=False):
102
+ logger.info("Running validation... ")
103
+
104
+ image_logs = []
105
+
106
+ # validation_batch = batchify_pil(args.validation_image, args.validation_prompt, deg_pipeline, image_processor)
107
+ lq = [Image.open(lq_example).convert("RGB") for lq_example in args.validation_image]
108
+
109
+ pipe = InstantIRPipeline(
110
+ vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2,
111
+ unet, scheduler, aggregator, feature_extractor=image_processor, image_encoder=image_encoder,
112
+ ).to(accelerator.device)
113
+
114
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
115
+ if lq_img is not None and gt_img is not None:
116
+ lq_img = lq_img[:len(args.validation_image)]
117
+ lq_pt = image_processor(
118
+ images=lq_img*0.5+0.5,
119
+ do_rescale=False, return_tensors="pt"
120
+ ).pixel_values
121
+ image = pipe(
122
+ prompt=[""]*len(lq_img),
123
+ image=lq_img,
124
+ ip_adapter_image=lq_pt,
125
+ num_inference_steps=20,
126
+ generator=generator,
127
+ controlnet_conditioning_scale=1.0,
128
+ negative_prompt=[""]*len(lq),
129
+ guidance_scale=5.0,
130
+ height=args.resolution,
131
+ width=args.resolution,
132
+ lcm_scheduler=lcm_scheduler,
133
+ ).images
134
+ else:
135
+ image = pipe(
136
+ prompt=[""]*len(lq),
137
+ image=lq,
138
+ ip_adapter_image=lq,
139
+ num_inference_steps=20,
140
+ generator=generator,
141
+ controlnet_conditioning_scale=1.0,
142
+ negative_prompt=[""]*len(lq),
143
+ guidance_scale=5.0,
144
+ height=args.resolution,
145
+ width=args.resolution,
146
+ lcm_scheduler=lcm_scheduler,
147
+ ).images
148
+
149
+ if log_local:
150
+ for i, rec_image in enumerate(image):
151
+ rec_image.save(f"./instantid_{i}.png")
152
+ return
153
+
154
+ tracker_key = "test" if is_final_validation else "validation"
155
+ for tracker in accelerator.trackers:
156
+ if tracker.name == "tensorboard":
157
+ images = [np.asarray(pil_img) for pil_img in image]
158
+ images = np.stack(images, axis=0)
159
+ if lq_img is not None and gt_img is not None:
160
+ input_lq = lq_img.cpu()
161
+ input_lq = np.asarray(input_lq.add(1).div(2).clamp(0, 1))
162
+ input_gt = gt_img.cpu()
163
+ input_gt = np.asarray(input_gt.add(1).div(2).clamp(0, 1))
164
+ tracker.writer.add_images("lq", input_lq, step, dataformats="NCHW")
165
+ tracker.writer.add_images("gt", input_gt, step, dataformats="NCHW")
166
+ tracker.writer.add_images("rec", images, step, dataformats="NHWC")
167
+ elif tracker.name == "wandb":
168
+ raise NotImplementedError("Wandb logging not implemented for validation.")
169
+ formatted_images = []
170
+
171
+ for log in image_logs:
172
+ images = log["images"]
173
+ validation_prompt = log["validation_prompt"]
174
+ validation_image = log["validation_image"]
175
+
176
+ formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
177
+
178
+ for image in images:
179
+ image = wandb.Image(image, caption=validation_prompt)
180
+ formatted_images.append(image)
181
+
182
+ tracker.log({tracker_key: formatted_images})
183
+ else:
184
+ logger.warning(f"image logging not implemented for {tracker.name}")
185
+
186
+ gc.collect()
187
+ torch.cuda.empty_cache()
188
+
189
+ return image_logs
190
+
191
+
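+ # Recursively removes every `attn2` (text cross-attention) module and its paired `norm2` from the up/down/mid blocks.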
192
+ def remove_attn2(model):
193
+ def recursive_find_module(name, module):
194
+ if not "up_blocks" in name and not "down_blocks" in name and not "mid_block" in name: return
195
+ elif "resnets" in name: return
196
+ if hasattr(module, "attn2"):
197
+ setattr(module, "attn2", None)
198
+ setattr(module, "norm2", None)
199
+ return
200
+ for sub_name, sub_module in module.named_children():
201
+ recursive_find_module(f"{name}.{sub_name}", sub_module)
202
+
203
+ for name, module in model.named_children():
204
+ recursive_find_module(name, module)
205
+
206
+
207
+ def parse_args(input_args=None):
208
+ parser = argparse.ArgumentParser(description="Simple example of an IP-Adapter training script.")
209
+ parser.add_argument(
210
+ "--pretrained_model_name_or_path",
211
+ type=str,
212
+ default=None,
213
+ required=True,
214
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
215
+ )
216
+ parser.add_argument(
217
+ "--pretrained_vae_model_name_or_path",
218
+ type=str,
219
+ default=None,
220
+ help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.",
221
+ )
222
+ parser.add_argument(
223
+ "--controlnet_model_name_or_path",
224
+ type=str,
225
+ default=None,
226
+ help="Path to a pretrained ControlNet model, e.g. a tile ControlNet.",
227
+ )
228
+ parser.add_argument(
229
+ "--use_lcm",
230
+ action="store_true",
231
+ help="Whether or not to use lcm unet.",
232
+ )
233
+ parser.add_argument(
234
+ "--pretrained_lcm_lora_path",
235
+ type=str,
236
+ default=None,
237
+ help="Path to LCM lora or model identifier from huggingface.co/models.",
238
+ )
239
+ parser.add_argument(
240
+ "--lora_rank",
241
+ type=int,
242
+ default=64,
243
+ help="The rank of the LoRA projection matrix.",
244
+ )
245
+ parser.add_argument(
246
+ "--lora_alpha",
247
+ type=int,
248
+ default=64,
249
+ help=(
250
+ "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight"
251
+ " update delta_W. No scaling will be performed if this value is equal to `lora_rank`."
252
+ ),
253
+ )
254
+ parser.add_argument(
255
+ "--lora_dropout",
256
+ type=float,
257
+ default=0.0,
258
+ help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.",
259
+ )
260
+ parser.add_argument(
261
+ "--lora_target_modules",
262
+ type=str,
263
+ default=None,
264
+ help=(
265
+ "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will"
266
+ " be used. By default, LoRA will be applied to all conv and linear layers."
267
+ ),
268
+ )
269
+ parser.add_argument(
270
+ "--feature_extractor_path",
271
+ type=str,
272
+ default=None,
273
+ help="Path to image encoder for IP-Adapters or model identifier from huggingface.co/models.",
274
+ )
275
+ parser.add_argument(
276
+ "--pretrained_adapter_model_path",
277
+ type=str,
278
+ default=None,
279
+ help="Path to IP-Adapter models or model identifier from huggingface.co/models.",
280
+ )
281
+ parser.add_argument(
282
+ "--adapter_tokens",
283
+ type=int,
284
+ default=64,
285
+ help="Number of tokens to use in IP-adapter cross attention mechanism.",
286
+ )
287
+ parser.add_argument(
288
+ "--aggregator_adapter",
289
+ action="store_true",
290
+ help="Whether or not to add adapter on aggregator.",
291
+ )
292
+ parser.add_argument(
293
+ "--optimize_adapter",
294
+ action="store_true",
295
+ help="Whether or not to optimize IP-Adapter.",
296
+ )
297
+ parser.add_argument(
298
+ "--image_encoder_hidden_feature",
299
+ action="store_true",
300
+ help="Whether or not to use the penultimate hidden states as image embeddings.",
301
+ )
302
+ parser.add_argument(
303
+ "--losses_config_path",
304
+ type=str,
305
+ required=True,
306
+ help=("A yaml file containing losses to use and their weights."),
307
+ )
308
+ parser.add_argument(
309
+ "--data_config_path",
310
+ type=str,
311
+ default=None,
312
+ help=("Path to a data config file (e.g. a yaml) describing the training datasets."),
313
+ )
314
+ parser.add_argument(
315
+ "--variant",
316
+ type=str,
317
+ default=None,
318
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
319
+ )
320
+ parser.add_argument(
321
+ "--revision",
322
+ type=str,
323
+ default=None,
324
+ required=False,
325
+ help="Revision of pretrained model identifier from huggingface.co/models.",
326
+ )
327
+ parser.add_argument(
328
+ "--tokenizer_name",
329
+ type=str,
330
+ default=None,
331
+ help="Pretrained tokenizer name or path if not the same as model_name",
332
+ )
333
+ parser.add_argument(
334
+ "--output_dir",
335
+ type=str,
336
+ default="stage1_model",
337
+ help="The output directory where the model predictions and checkpoints will be written.",
338
+ )
339
+ parser.add_argument(
340
+ "--cache_dir",
341
+ type=str,
342
+ default=None,
343
+ help="The directory where the downloaded models and datasets will be stored.",
344
+ )
345
+ parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
346
+ parser.add_argument(
347
+ "--resolution",
348
+ type=int,
349
+ default=512,
350
+ help=(
351
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
352
+ " resolution"
353
+ ),
354
+ )
355
+ parser.add_argument(
356
+ "--crops_coords_top_left_h",
357
+ type=int,
358
+ default=0,
359
+ help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
360
+ )
361
+ parser.add_argument(
362
+ "--crops_coords_top_left_w",
363
+ type=int,
364
+ default=0,
365
+ help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
366
+ )
367
+ parser.add_argument(
368
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
369
+ )
370
+ parser.add_argument("--num_train_epochs", type=int, default=1)
371
+ parser.add_argument(
372
+ "--max_train_steps",
373
+ type=int,
374
+ default=None,
375
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
376
+ )
377
+ parser.add_argument(
378
+ "--checkpointing_steps",
379
+ type=int,
380
+ default=3000,
381
+ help=(
382
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
383
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
384
+ " Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
385
+ " See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step"
386
+ " instructions."
387
+ ),
388
+ )
389
+ parser.add_argument(
390
+ "--checkpoints_total_limit",
391
+ type=int,
392
+ default=5,
393
+ help=("Max number of checkpoints to store."),
394
+ )
395
+ parser.add_argument(
396
+ "--previous_ckpt",
397
+ type=str,
398
+ default=None,
399
+ help=(
400
+ "Whether training should be initialized from a previous checkpoint."
401
+ ),
402
+ )
403
+ parser.add_argument(
404
+ "--resume_from_checkpoint",
405
+ type=str,
406
+ default=None,
407
+ help=(
408
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
409
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
410
+ ),
411
+ )
412
+ parser.add_argument(
413
+ "--gradient_accumulation_steps",
414
+ type=int,
415
+ default=1,
416
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
417
+ )
418
+ parser.add_argument(
419
+ "--gradient_checkpointing",
420
+ action="store_true",
421
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
422
+ )
423
+ parser.add_argument(
424
+ "--save_only_adapter",
425
+ action="store_true",
426
+ help="Only save extra adapter to save space.",
427
+ )
428
+ parser.add_argument(
429
+ "--cache_prompt_embeds",
430
+ action="store_true",
431
+ help="Whether or not to cache prompt embeds to save memory.",
432
+ )
433
+ parser.add_argument(
434
+ "--importance_sampling",
435
+ action="store_true",
436
+ help="Whether or not to use importance sampling.",
437
+ )
438
+ parser.add_argument(
439
+ "--CFG_scale",
440
+ type=float,
441
+ default=1.0,
442
+ help="CFG for previewer.",
443
+ )
444
+ parser.add_argument(
445
+ "--learning_rate",
446
+ type=float,
447
+ default=1e-4,
448
+ help="Initial learning rate (after the potential warmup period) to use.",
449
+ )
450
+ parser.add_argument(
451
+ "--scale_lr",
452
+ action="store_true",
453
+ default=False,
454
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
455
+ )
456
+ parser.add_argument(
457
+ "--lr_scheduler",
458
+ type=str,
459
+ default="constant",
460
+ help=(
461
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
462
+ ' "constant", "constant_with_warmup"]'
463
+ ),
464
+ )
465
+ parser.add_argument(
466
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
467
+ )
468
+ parser.add_argument(
469
+ "--lr_num_cycles",
470
+ type=int,
471
+ default=1,
472
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
473
+ )
474
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
475
+ parser.add_argument(
476
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
477
+ )
478
+ parser.add_argument(
479
+ "--dataloader_num_workers",
480
+ type=int,
481
+ default=0,
482
+ help=(
483
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
484
+ ),
485
+ )
486
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
487
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
488
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
489
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
490
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
491
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
492
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
493
+ parser.add_argument(
494
+ "--hub_model_id",
495
+ type=str,
496
+ default=None,
497
+ help="The name of the repository to keep in sync with the local `output_dir`.",
498
+ )
499
+ parser.add_argument(
500
+ "--logging_dir",
501
+ type=str,
502
+ default="logs",
503
+ help=(
504
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
505
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
506
+ ),
507
+ )
508
+ parser.add_argument(
509
+ "--allow_tf32",
510
+ action="store_true",
511
+ help=(
512
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
513
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
514
+ ),
515
+ )
516
+ parser.add_argument(
517
+ "--report_to",
518
+ type=str,
519
+ default="tensorboard",
520
+ help=(
521
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
522
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
523
+ ),
524
+ )
525
+ parser.add_argument(
526
+ "--mixed_precision",
527
+ type=str,
528
+ default=None,
529
+ choices=["no", "fp16", "bf16"],
530
+ help=(
531
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
532
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
533
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
534
+ ),
535
+ )
536
+ parser.add_argument(
537
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
538
+ )
539
+ parser.add_argument(
540
+ "--set_grads_to_none",
541
+ action="store_true",
542
+ help=(
543
+ "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
544
+ " behaviors, so disable this argument if it causes any problems. More info:"
545
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
546
+ ),
547
+ )
548
+ parser.add_argument(
549
+ "--dataset_name",
550
+ type=str,
551
+ default=None,
552
+ help=(
553
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
554
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
555
+ " or to a folder containing files that 🤗 Datasets can understand."
556
+ ),
557
+ )
558
+ parser.add_argument(
559
+ "--dataset_config_name",
560
+ type=str,
561
+ default=None,
562
+ help="The config of the Dataset, leave as None if there's only one config.",
563
+ )
564
+ parser.add_argument(
565
+ "--train_data_dir",
566
+ type=str,
567
+ default=None,
568
+ help=(
569
+ "A folder containing the training data. Folder contents must follow the structure described in"
570
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
571
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
572
+ ),
573
+ )
574
+ parser.add_argument(
575
+ "--image_column", type=str, default="image", help="The column of the dataset containing the target image."
576
+ )
577
+ parser.add_argument(
578
+ "--conditioning_image_column",
579
+ type=str,
580
+ default="conditioning_image",
581
+ help="The column of the dataset containing the controlnet conditioning image.",
582
+ )
583
+ parser.add_argument(
584
+ "--caption_column",
585
+ type=str,
586
+ default="text",
587
+ help="The column of the dataset containing a caption or a list of captions.",
588
+ )
589
+ parser.add_argument(
590
+ "--max_train_samples",
591
+ type=int,
592
+ default=None,
593
+ help=(
594
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
595
+ "value if set."
596
+ ),
597
+ )
598
+ parser.add_argument(
599
+ "--text_drop_rate",
600
+ type=float,
601
+ default=0,
602
+ help="Proportion of text prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
603
+ )
604
+ parser.add_argument(
605
+ "--image_drop_rate",
606
+ type=float,
607
+ default=0,
608
+ help="Proportion of IP-Adapter inputs to be dropped. Defaults to 0 (no drop-out).",
609
+ )
610
+ parser.add_argument(
611
+ "--cond_drop_rate",
612
+ type=float,
613
+ default=0,
614
+ help="Proportion of all conditions to be dropped. Defaults to 0 (no drop-out).",
615
+ )
616
+ parser.add_argument(
617
+ "--use_ema_adapter",
618
+ action="store_true",
619
+ help=(
620
+ "Use an EMA copy of the IP-Adapter for the LCM preview."
621
+ ),
622
+ )
623
+ parser.add_argument(
624
+ "--sanity_check",
625
+ action="store_true",
626
+ help=(
627
+ "sanity check"
628
+ ),
629
+ )
630
+ parser.add_argument(
631
+ "--validation_prompt",
632
+ type=str,
633
+ default=None,
634
+ nargs="+",
635
+ help=(
636
+ "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."
637
+ " Provide either a matching number of `--validation_image`s, a single `--validation_image`"
638
+ " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s."
639
+ ),
640
+ )
641
+ parser.add_argument(
642
+ "--validation_image",
643
+ type=str,
644
+ default=None,
645
+ nargs="+",
646
+ help=(
647
+ "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
648
+ " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s,"
649
+ " a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
650
+ " `--validation_image` that will be used with all `--validation_prompt`s."
651
+ ),
652
+ )
653
+ parser.add_argument(
654
+ "--num_validation_images",
655
+ type=int,
656
+ default=4,
657
+ help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair",
658
+ )
659
+ parser.add_argument(
660
+ "--validation_steps",
661
+ type=int,
662
+ default=4000,
663
+ help=(
664
+ "Run validation every X steps. Validation consists of generating"
665
+ " `args.num_validation_images` images for each `args.validation_prompt`"
666
+ " and logging the results."
667
+ ),
668
+ )
669
+ parser.add_argument(
670
+ "--tracker_project_name",
671
+ type=str,
672
+ default='train',
673
+ help=(
674
+ "The `project_name` argument passed to Accelerator.init_trackers for"
675
+ " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
676
+ ),
677
+ )
678
+
679
+ if input_args is not None:
680
+ args = parser.parse_args(input_args)
681
+ else:
682
+ args = parser.parse_args()
683
+
684
+ if not args.sanity_check and args.dataset_name is None and args.train_data_dir is None and args.data_config_path is None:
685
+ raise ValueError("Specify either `--dataset_name`, `--train_data_dir`, or `--data_config_path`")
686
+
687
+ if args.dataset_name is not None and args.train_data_dir is not None:
688
+ raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`")
689
+
690
+ if args.text_drop_rate < 0 or args.text_drop_rate > 1:
691
+ raise ValueError("`--text_drop_rate` must be in the range [0, 1].")
692
+
693
+ if args.validation_prompt is not None and args.validation_image is None:
694
+ raise ValueError("`--validation_image` must be set if `--validation_prompt` is set")
695
+
696
+ if args.validation_prompt is None and args.validation_image is not None:
697
+ raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")
698
+
699
+ if (
700
+ args.validation_image is not None
701
+ and args.validation_prompt is not None
702
+ and len(args.validation_image) != 1
703
+ and len(args.validation_prompt) != 1
704
+ and len(args.validation_image) != len(args.validation_prompt)
705
+ ):
706
+ raise ValueError(
707
+ "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
708
+ " or the same number of `--validation_prompt`s and `--validation_image`s"
709
+ )
710
+
711
+ if args.resolution % 8 != 0:
712
+ raise ValueError(
713
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
714
+ )
715
+
716
+ return args
717
+
718
+
719
+ def update_ema_model(ema_model, model, ema_beta):
720
+ for ema_param, param in zip(ema_model.parameters(), model.parameters()):
721
+ ema_param.copy_(param.detach().lerp(ema_param, ema_beta))
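+ # Equivalent to ema = ema_beta * ema + (1 - ema_beta) * param.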
722
+
723
+
724
+ def copy_dict(dict):
725
+ new_dict = {}
726
+ for key, value in dict.items():
727
+ new_dict[key] = value
728
+ return new_dict
729
+
730
+
731
+ def main(args):
732
+ if args.report_to == "wandb" and args.hub_token is not None:
733
+ raise ValueError(
734
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
735
+ " Please use `huggingface-cli login` to authenticate with the Hub."
736
+ )
737
+
738
+ logging_dir = Path(args.output_dir, args.logging_dir)
739
+
740
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
741
+ # due to pytorch#99272, MPS does not yet support bfloat16.
742
+ raise ValueError(
743
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
744
+ )
745
+
746
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
747
+
748
+ accelerator = Accelerator(
749
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
750
+ mixed_precision=args.mixed_precision,
751
+ log_with=args.report_to,
752
+ project_config=accelerator_project_config,
753
+ )
754
+
755
+ # Make one log on every process with the configuration for debugging.
756
+ logging.basicConfig(
757
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
758
+ datefmt="%m/%d/%Y %H:%M:%S",
759
+ level=logging.INFO,
760
+ )
761
+ logger.info(accelerator.state, main_process_only=False)
762
+ if accelerator.is_local_main_process:
763
+ transformers.utils.logging.set_verbosity_warning()
764
+ diffusers.utils.logging.set_verbosity_info()
765
+ else:
766
+ transformers.utils.logging.set_verbosity_error()
767
+ diffusers.utils.logging.set_verbosity_error()
768
+
769
+ # If passed along, set the training seed now.
770
+ if args.seed is not None:
771
+ set_seed(args.seed)
772
+
773
+ # Handle the repository creation.
774
+ if accelerator.is_main_process:
775
+ if args.output_dir is not None:
776
+ os.makedirs(args.output_dir, exist_ok=True)
777
+
778
+ # Load scheduler and models
779
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
780
+
781
+ # Importance sampling.
782
+ list_of_candidates = np.arange(noise_scheduler.config.num_train_timesteps, dtype='float64')
783
+ prob_dist = importance_sampling_fn(list_of_candidates, noise_scheduler.config.num_train_timesteps, 0.5)
784
+ importance_ratio = prob_dist / prob_dist.sum() * noise_scheduler.config.num_train_timesteps
785
+ importance_ratio = torch.from_numpy(importance_ratio.copy()).float()
786
+
787
+ # Load the tokenizers
788
+ tokenizer = AutoTokenizer.from_pretrained(
789
+ args.pretrained_model_name_or_path,
790
+ subfolder="tokenizer",
791
+ revision=args.revision,
792
+ use_fast=False,
793
+ )
794
+ tokenizer_2 = AutoTokenizer.from_pretrained(
795
+ args.pretrained_model_name_or_path,
796
+ subfolder="tokenizer_2",
797
+ revision=args.revision,
798
+ use_fast=False,
799
+ )
800
+
801
+ # Text encoder and image encoder.
802
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
803
+ args.pretrained_model_name_or_path, args.revision
804
+ )
805
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
806
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
807
+ )
808
+ text_encoder = text_encoder_cls_one.from_pretrained(
809
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
810
+ )
811
+ text_encoder_2 = text_encoder_cls_two.from_pretrained(
812
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
813
+ )
814
+
815
+ # Image processor and image encoder.
816
+ if args.use_clip_encoder:
817
+ image_processor = CLIPImageProcessor()
818
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.feature_extractor_path)
819
+ else:
820
+ image_processor = AutoImageProcessor.from_pretrained(args.feature_extractor_path)
821
+ image_encoder = AutoModel.from_pretrained(args.feature_extractor_path)
822
+
823
+ # VAE.
824
+ vae_path = (
825
+ args.pretrained_model_name_or_path
826
+ if args.pretrained_vae_model_name_or_path is None
827
+ else args.pretrained_vae_model_name_or_path
828
+ )
829
+ vae = AutoencoderKL.from_pretrained(
830
+ vae_path,
831
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
832
+ revision=args.revision,
833
+ variant=args.variant,
834
+ )
835
+
836
+ # UNet.
837
+ unet = UNet2DConditionModel.from_pretrained(
838
+ args.pretrained_model_name_or_path,
839
+ subfolder="unet",
840
+ revision=args.revision,
841
+ variant=args.variant
842
+ )
843
+
844
+ # Aggregator.
845
+ aggregator = Aggregator.from_unet(unet)
846
+ remove_attn2(aggregator)
847
+ if args.controlnet_model_name_or_path:
848
+ logger.info("Loading existing controlnet weights")
849
+ if args.controlnet_model_name_or_path.endswith(".safetensors"):
850
+ pretrained_cn_state_dict = {}
851
+ with safe_open(args.controlnet_model_name_or_path, framework="pt", device='cpu') as f:
852
+ for key in f.keys():
853
+ pretrained_cn_state_dict[key] = f.get_tensor(key)
854
+ else:
855
+ pretrained_cn_state_dict = torch.load(os.path.join(args.controlnet_model_name_or_path, "aggregator_ckpt.pt"), map_location="cpu")
856
+ aggregator.load_state_dict(pretrained_cn_state_dict, strict=True)
857
+ else:
858
+ logger.info("Initializing aggregator weights from unet.")
859
+
860
+ # Create image embedding projector for IP-Adapters.
861
+ if args.pretrained_adapter_model_path is not None:
862
+ if args.pretrained_adapter_model_path.endswith(".safetensors"):
863
+ pretrained_adapter_state_dict = {"image_proj": {}, "ip_adapter": {}}
864
+ with safe_open(args.pretrained_adapter_model_path, framework="pt", device="cpu") as f:
865
+ for key in f.keys():
866
+ if key.startswith("image_proj."):
867
+ pretrained_adapter_state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
868
+ elif key.startswith("ip_adapter."):
869
+ pretrained_adapter_state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
870
+ else:
871
+ pretrained_adapter_state_dict = torch.load(args.pretrained_adapter_model_path, map_location="cpu")
872
+
873
+ # Image embedding Projector.
874
+ image_proj_model = Resampler(
875
+ dim=1280,
876
+ depth=4,
877
+ dim_head=64,
878
+ heads=20,
879
+ num_queries=args.adapter_tokens,
880
+ embedding_dim=image_encoder.config.hidden_size,
881
+ output_dim=unet.config.cross_attention_dim,
882
+ ff_mult=4
883
+ )
884
+
885
+ init_adapter_in_unet(
886
+ unet,
887
+ image_proj_model,
888
+ pretrained_adapter_state_dict,
889
+ adapter_tokens=args.adapter_tokens,
890
+ )
891
+
892
+ # EMA adapter for LCM preview.
893
+ if args.use_ema_adapter:
894
+ assert args.optimize_adapter, "No need for EMA with frozen adapter."
895
+ ema_image_proj_model = Resampler(
896
+ dim=1280,
897
+ depth=4,
898
+ dim_head=64,
899
+ heads=20,
900
+ num_queries=args.adapter_tokens,
901
+ embedding_dim=image_encoder.config.hidden_size,
902
+ output_dim=unet.config.cross_attention_dim,
903
+ ff_mult=4
904
+ )
905
+ orig_encoder_hid_proj = unet.encoder_hid_proj
906
+ ema_encoder_hid_proj = MultiIPAdapterImageProjection([ema_image_proj_model])
907
+ orig_attn_procs = unet.attn_processors
908
+ orig_attn_procs_list = torch.nn.ModuleList(orig_attn_procs.values())
909
+ ema_attn_procs = init_attn_proc(unet, args.adapter_tokens, True, True, False)
910
+ ema_attn_procs_list = torch.nn.ModuleList(ema_attn_procs.values())
911
+ ema_attn_procs_list.requires_grad_(False)
912
+ ema_encoder_hid_proj.requires_grad_(False)
913
+
914
+ # Initialize EMA state.
915
+ ema_beta = 0.5 ** (args.ema_update_steps / max(args.ema_halflife_steps, 1e-8))
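+ # Decay chosen so the old average's contribution halves every `ema_halflife_steps` training steps when the EMA is refreshed every `ema_update_steps` steps.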
916
+ logger.info(f"Using EMA with beta: {ema_beta}")
917
+ ema_encoder_hid_proj.load_state_dict(orig_encoder_hid_proj.state_dict())
918
+ ema_attn_procs_list.load_state_dict(orig_attn_procs_list.state_dict())
919
+
920
+ # Projector for aggregator.
921
+ if args.aggregator_adapter:
922
+ image_proj_model = Resampler(
923
+ dim=1280,
924
+ depth=4,
925
+ dim_head=64,
926
+ heads=20,
927
+ num_queries=args.adapter_tokens,
928
+ embedding_dim=image_encoder.config.hidden_size,
929
+ output_dim=unet.config.cross_attention_dim,
930
+ ff_mult=4
931
+ )
932
+
933
+ init_adapter_in_unet(
934
+ aggregator,
935
+ image_proj_model,
936
+ pretrained_adapter_state_dict,
937
+ adapter_tokens=args.adapter_tokens,
938
+ )
939
+ del pretrained_adapter_state_dict
940
+
941
+ # Load LCM LoRA into unet.
942
+ if args.pretrained_lcm_lora_path is not None:
943
+ lora_state_dict, alpha_dict = StableDiffusionXLPipeline.lora_state_dict(args.pretrained_lcm_lora_path)
944
+ unet_state_dict = {
945
+ f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")
946
+ }
947
+ unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict)
948
+ lora_state_dict = dict()
949
+ for k, v in unet_state_dict.items():
950
+ if "ip" in k:
951
+ k = k.replace("attn2", "attn2.processor")
952
+ lora_state_dict[k] = v
953
+ else:
954
+ lora_state_dict[k] = v
955
+ if alpha_dict:
956
+ args.lora_alpha = next(iter(alpha_dict.values()))
957
+ else:
958
+ args.lora_alpha = 1
959
+ logger.info(f"Loaded LCM LoRA with alpha: {args.lora_alpha}")
960
+ # Create LoRA config, FIXME: now hard-coded.
961
+ lora_target_modules = [
962
+ "to_q",
963
+ "to_kv",
964
+ "0.to_out",
965
+ "attn1.to_k",
966
+ "attn1.to_v",
967
+ "to_k_ip",
968
+ "to_v_ip",
969
+ "ln_k_ip.linear",
970
+ "ln_v_ip.linear",
971
+ "to_out.0",
972
+ "proj_in",
973
+ "proj_out",
974
+ "ff.net.0.proj",
975
+ "ff.net.2",
976
+ "conv1",
977
+ "conv2",
978
+ "conv_shortcut",
979
+ "downsamplers.0.conv",
980
+ "upsamplers.0.conv",
981
+ "time_emb_proj",
982
+ ]
983
+ lora_config = LoraConfig(
984
+ r=args.lora_rank,
985
+ target_modules=lora_target_modules,
986
+ lora_alpha=args.lora_alpha,
987
+ lora_dropout=args.lora_dropout,
988
+ )
989
+
990
+ unet.add_adapter(lora_config)
991
+ if args.pretrained_lcm_lora_path is not None:
992
+ incompatible_keys = set_peft_model_state_dict(unet, lora_state_dict, adapter_name="default")
993
+ if incompatible_keys is not None:
994
+ # check for unexpected keys, and for missing LoRA keys
995
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
996
+ missing_keys = getattr(incompatible_keys, "missing_keys", None)
997
+ if unexpected_keys:
998
+ raise ValueError(
999
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
1000
+ f" {unexpected_keys}. "
1001
+ )
1002
+ for k in missing_keys:
1003
+ if "lora" in k:
1004
+ raise ValueError(
1005
+ f"Loading adapter weights from state_dict led to missing keys: {missing_keys}. "
1006
+ )
1007
+ lcm_scheduler = LCMSingleStepScheduler.from_config(noise_scheduler.config)
1008
+
1009
+ # Initialize training state.
1010
+ vae.requires_grad_(False)
1011
+ image_encoder.requires_grad_(False)
1012
+ text_encoder.requires_grad_(False)
1013
+ text_encoder_2.requires_grad_(False)
1014
+ unet.requires_grad_(False)
1015
+ aggregator.requires_grad_(False)
1016
+
1017
+ def unwrap_model(model):
1018
+ model = accelerator.unwrap_model(model)
1019
+ model = model._orig_mod if is_compiled_module(model) else model
1020
+ return model
1021
+
1022
+ # `accelerate` 0.16.0 will have better support for customized saving
1023
+ if args.save_only_adapter:
1024
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
1025
+ def save_model_hook(models, weights, output_dir):
1026
+ if accelerator.is_main_process:
1027
+ for model in models:
1028
+ if isinstance(model, Aggregator):
1029
+ torch.save(model.state_dict(), os.path.join(output_dir, "aggregator_ckpt.pt"))
1030
+ weights.pop()
1031
+
1032
+ def load_model_hook(models, input_dir):
1033
+
1034
+ while len(models) > 0:
1035
+ # pop models so that they are not loaded again
1036
+ model = models.pop()
1037
+
1038
+ if isinstance(model, Aggregator):
1039
+ aggregator_state_dict = torch.load(os.path.join(input_dir, "aggregator_ckpt.pt"), map_location="cpu")
1040
+ model.load_state_dict(aggregator_state_dict)
1041
+
1042
+ accelerator.register_save_state_pre_hook(save_model_hook)
1043
+ accelerator.register_load_state_pre_hook(load_model_hook)
1044
+
1045
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
1046
+ # as these models are only used for inference, keeping weights in full precision is not required.
1047
+ weight_dtype = torch.float32
1048
+ if accelerator.mixed_precision == "fp16":
1049
+ weight_dtype = torch.float16
1050
+ elif accelerator.mixed_precision == "bf16":
1051
+ weight_dtype = torch.bfloat16
1052
+
1053
+ if args.enable_xformers_memory_efficient_attention:
1054
+ if is_xformers_available():
1055
+ import xformers
1056
+
1057
+ xformers_version = version.parse(xformers.__version__)
1058
+ if xformers_version == version.parse("0.0.16"):
1059
+ logger.warning(
1060
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1061
+ )
1062
+ unet.enable_xformers_memory_efficient_attention()
1063
+ else:
1064
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
1065
+
1066
+ if args.gradient_checkpointing:
1067
+ aggregator.enable_gradient_checkpointing()
1068
+ unet.enable_gradient_checkpointing()
1069
+
1070
+ # Check that all trainable models are in full precision
1071
+ low_precision_error_string = (
1072
+ " Please make sure to always have all model weights in full float32 precision when starting training - even if"
1073
+ " doing mixed precision training, copy of the weights should still be float32."
1074
+ )
1075
+
1076
+ if unwrap_model(aggregator).dtype != torch.float32:
1077
+ raise ValueError(
1078
+ f"aggregator loaded as datatype {unwrap_model(aggregator).dtype}. {low_precision_error_string}"
1079
+ )
1080
+
1081
+ # Enable TF32 for faster training on Ampere GPUs,
1082
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1083
+ if args.allow_tf32:
1084
+ torch.backends.cuda.matmul.allow_tf32 = True
1085
+
1086
+ if args.scale_lr:
1087
+ args.learning_rate = (
1088
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1089
+ )
1090
+
1091
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
1092
+ if args.use_8bit_adam:
1093
+ try:
1094
+ import bitsandbytes as bnb
1095
+ except ImportError:
1096
+ raise ImportError(
1097
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1098
+ )
1099
+
1100
+ optimizer_class = bnb.optim.AdamW8bit
1101
+ else:
1102
+ optimizer_class = torch.optim.AdamW
1103
+
1104
+ # Optimizer creation
1105
+ ip_params, non_ip_params = seperate_ip_params_from_unet(unet)
1106
+ if args.optimize_adapter:
1107
+ unet_params = ip_params
1108
+ unet_frozen_params = non_ip_params
1109
+ else: # freeze all unet params
1110
+ unet_params = []
1111
+ unet_frozen_params = ip_params + non_ip_params
1112
+ assert len(unet_frozen_params) == len(list(unet.parameters()))
1113
+ params_to_optimize = [p for p in aggregator.parameters()]
1114
+ params_to_optimize += unet_params
1115
+ optimizer = optimizer_class(
1116
+ params_to_optimize,
1117
+ lr=args.learning_rate,
1118
+ betas=(args.adam_beta1, args.adam_beta2),
1119
+ weight_decay=args.adam_weight_decay,
1120
+ eps=args.adam_epsilon,
1121
+ )
1122
+
1123
+ # Instantiate Loss.
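+ # Loss classes are resolved by name from the YAML config via eval(); each entry is paired with its configured weight.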
1124
+ losses_configs: LossesConfig = pyrallis.load(LossesConfig, open(args.losses_config_path, "r"))
1125
+ diffusion_losses = list()
1126
+ lcm_losses = list()
1127
+ for loss_config in losses_configs.diffusion_losses:
1128
+ logger.info(f"Using diffusion loss: {loss_config.name}")
1129
+ loss = namedtuple("loss", ["loss", "weight"])
1130
+ diffusion_losses.append(
1131
+ loss(loss=eval(loss_config.name)(
1132
+ visualize_every_k=loss_config.visualize_every_k,
1133
+ dtype=weight_dtype,
1134
+ accelerator=accelerator,
1135
+ **loss_config.init_params), weight=loss_config.weight)
1136
+ )
1137
+ for loss_config in losses_configs.lcm_losses:
1138
+ logger.info(f"Using lcm loss: {loss_config.name}")
1139
+ loss = namedtuple("loss", ["loss", "weight"])
1140
+ loss_class = eval(loss_config.name)
1141
+ lcm_losses.append(loss(loss=loss_class(visualize_every_k=loss_config.visualize_every_k,
1142
+ dtype=weight_dtype,
1143
+ accelerator=accelerator,
1144
+ dino_model=image_encoder,
1145
+ dino_preprocess=image_processor,
1146
+ **loss_config.init_params), weight=loss_config.weight))
1147
+
1148
+ # SDXL additional condition that will be added to time embedding.
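+ # Packs (original_size, crop_top_left, target_size) into one 6-dim vector per sample, mirroring SDXL's micro-conditioning inputs.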
1149
+ def compute_time_ids(original_size, crops_coords_top_left):
1150
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1151
+ target_size = (args.resolution, args.resolution)
1152
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1153
+ add_time_ids = torch.tensor([add_time_ids])
1154
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1155
+ return add_time_ids
1156
+
1157
+ # Text prompt embeddings.
1158
+ @torch.no_grad()
1159
+ def compute_embeddings(batch, text_encoders, tokenizers, proportion_empty_prompts=0.0, drop_idx=None, is_train=True):
1160
+ prompt_batch = batch[args.caption_column]
1161
+ if drop_idx is not None:
1162
+ for i in range(len(prompt_batch)):
1163
+ prompt_batch[i] = "" if drop_idx[i] else prompt_batch[i]
1164
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1165
+ prompt_batch, text_encoders, tokenizers, is_train
1166
+ )
1167
+
1168
+ add_time_ids = torch.cat(
1169
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
1170
+ )
1171
+
1172
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1173
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
1174
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
1175
+ unet_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
1176
+
1177
+ return prompt_embeds, unet_added_cond_kwargs
1178
+
1179
+ @torch.no_grad()
1180
+ def get_added_cond(batch, prompt_embeds, pooled_prompt_embeds, proportion_empty_prompts=0.0, drop_idx=None, is_train=True):
1181
+
1182
+ add_time_ids = torch.cat(
1183
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
1184
+ )
1185
+
1186
+ prompt_embeds = prompt_embeds.to(accelerator.device)
1187
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
1188
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
1189
+ unet_added_cond_kwargs = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
1190
+
1191
+ return prompt_embeds, unet_added_cond_kwargs
1192
+
1193
+ # Move pixels into latents.
1194
+ @torch.no_grad()
1195
+ def convert_to_latent(pixels):
1196
+ model_input = vae.encode(pixels).latent_dist.sample()
1197
+ model_input = model_input * vae.config.scaling_factor
1198
+ if args.pretrained_vae_model_name_or_path is None:
1199
+ model_input = model_input.to(weight_dtype)
1200
+ return model_input
1201
+
1202
+ # Helper functions for training loop.
1203
+ # if args.degradation_config_path is not None:
1204
+ # with open(args.degradation_config_path) as stream:
1205
+ # degradation_configs = yaml.safe_load(stream)
1206
+ # deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
1207
+ # else:
1208
+ deg_pipeline = RealESRGANDegradation(device=accelerator.device, resolution=args.resolution)
1209
+ compute_embeddings_fn = functools.partial(
1210
+ compute_embeddings,
1211
+ text_encoders=[text_encoder, text_encoder_2],
1212
+ tokenizers=[tokenizer, tokenizer_2],
1213
+ is_train=True,
1214
+ )
1215
+
1216
+ datasets = []
1217
+ datasets_name = []
1218
+ datasets_weights = []
1219
+ if args.data_config_path is not None:
1220
+ data_config: DataConfig = pyrallis.load(DataConfig, open(args.data_config_path, "r"))
1221
+ for single_dataset in data_config.datasets:
1222
+ datasets_weights.append(single_dataset.dataset_weight)
1223
+ datasets_name.append(single_dataset.dataset_folder)
1224
+ dataset_dir = os.path.join(args.train_data_dir, single_dataset.dataset_folder)
1225
+ image_dataset = get_train_dataset(dataset_dir, dataset_dir, args, accelerator)
1226
+ image_dataset = prepare_train_dataset(image_dataset, accelerator, deg_pipeline)
1227
+ datasets.append(image_dataset)
1228
+ # TODO: Validation dataset
1229
+ if data_config.val_dataset is not None:
1230
+ val_dataset = get_train_dataset(dataset_name, dataset_dir, args, accelerator)
1231
+ logger.info(f"Datasets mixing: {list(zip(datasets_name, datasets_weights))}")
1232
+
1233
+ # Mix training datasets.
1234
+ sampler_train = None
1235
+ if len(datasets) == 1:
1236
+ train_dataset = datasets[0]
1237
+ else:
1238
+ # Weight each dataset according to its configured mixing ratio
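+ # Per-sample weight = (len(train_dataset) / len(single_dataset)) * dataset_weight, so the sampler follows the configured mixture regardless of dataset size.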
1239
+ train_dataset = torch.utils.data.ConcatDataset(datasets)
1240
+ dataset_weights = []
1241
+ for single_dataset, single_weight in zip(datasets, datasets_weights):
1242
+ dataset_weights.extend([len(train_dataset) / len(single_dataset) * single_weight] * len(single_dataset))
1243
+ sampler_train = torch.utils.data.WeightedRandomSampler(
1244
+ weights=dataset_weights,
1245
+ num_samples=len(dataset_weights)
1246
+ )
1247
+
1248
+ train_dataloader = torch.utils.data.DataLoader(
1249
+ train_dataset,
1250
+ batch_size=args.train_batch_size,
1251
+ sampler=sampler_train,
1252
+ shuffle=True if sampler_train is None else False,
1253
+ collate_fn=collate_fn,
1254
+ num_workers=args.dataloader_num_workers
1255
+ )
1256
+
1257
+ # We need to initialize the trackers we use, and also store our configuration.
1258
+ # The trackers initialize automatically on the main process.
1259
+ if accelerator.is_main_process:
1260
+ tracker_config = dict(vars(args))
1261
+
1262
+ # tensorboard cannot handle list types for config
1263
+ tracker_config.pop("validation_prompt")
1264
+ tracker_config.pop("validation_image")
1265
+
1266
+ accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
1267
+
1268
+ # Scheduler and math around the number of training steps.
1269
+ overrode_max_train_steps = False
1270
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1271
+ if args.max_train_steps is None:
1272
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1273
+ overrode_max_train_steps = True
1274
+
1275
+ lr_scheduler = get_scheduler(
1276
+ args.lr_scheduler,
1277
+ optimizer=optimizer,
1278
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1279
+ num_training_steps=args.max_train_steps,
1280
+ num_cycles=args.lr_num_cycles,
1281
+ power=args.lr_power,
1282
+ )
1283
+
1284
+ # Prepare everything with our `accelerator`.
1285
+ aggregator, unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1286
+ aggregator, unet, optimizer, train_dataloader, lr_scheduler
1287
+ )
1288
+
1289
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
1290
+ text_encoder_2.to(accelerator.device, dtype=weight_dtype)
1291
+
1292
+ # # cache empty prompts and move text encoders to cpu
1293
+ # empty_prompt_embeds, empty_pooled_prompt_embeds = encode_prompt(
1294
+ # [""]*args.train_batch_size, [text_encoder, text_encoder_2], [tokenizer, tokenizer_2], True
1295
+ # )
1296
+ # compute_embeddings_fn = functools.partial(
1297
+ # get_added_cond,
1298
+ # prompt_embeds=empty_prompt_embeds,
1299
+ # pooled_prompt_embeds=empty_pooled_prompt_embeds,
1300
+ # is_train=True,
1301
+ # )
1302
+ # text_encoder.to("cpu")
1303
+ # text_encoder_2.to("cpu")
1304
+
1305
+ # Move vae, unet and text_encoder to device and cast to `weight_dtype`.
1306
+ if args.pretrained_vae_model_name_or_path is None:
1307
+ # The VAE is fp32 to avoid NaN losses.
1308
+ vae.to(accelerator.device, dtype=torch.float32)
1309
+ else:
1310
+ vae.to(accelerator.device, dtype=weight_dtype)
1311
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
1312
+ if args.use_ema_adapter:
1313
+ # FIXME: prepare ema model
1314
+ # ema_encoder_hid_proj, ema_attn_procs_list = accelerator.prepare(ema_encoder_hid_proj, ema_attn_procs_list)
1315
+ ema_encoder_hid_proj.to(accelerator.device)
1316
+ ema_attn_procs_list.to(accelerator.device)
1317
+ for param in unet_frozen_params:
1318
+ param.data = param.data.to(dtype=weight_dtype)
1319
+ for param in unet_params:
1320
+ param.requires_grad_(True)
1321
+ unet.to(accelerator.device)
1322
+ aggregator.requires_grad_(True)
1323
+ aggregator.to(accelerator.device)
1324
+ importance_ratio = importance_ratio.to(accelerator.device)
1325
+
1326
+ # Final sanity check.
1327
+ for n, p in unet.named_parameters():
1328
+ assert not p.requires_grad, n
1329
+ if p.requires_grad:
1330
+ assert p.dtype == torch.float32, n
1331
+ else:
1332
+ assert p.dtype == weight_dtype, n
1333
+ for n, p in aggregator.named_parameters():
1334
+ if p.requires_grad: assert p.dtype == torch.float32, n
1335
+ else:
1336
+ raise RuntimeError(f"All parameters in aggregator should require grad. {n}")
1337
+ assert p.dtype == weight_dtype, n
1338
+ unwrap_model(unet).disable_adapters()
1339
+
1340
+ if args.sanity_check:
1341
+ if args.resume_from_checkpoint:
1342
+ if args.resume_from_checkpoint != "latest":
1343
+ path = os.path.basename(args.resume_from_checkpoint)
1344
+ else:
1345
+ # Get the most recent checkpoint
1346
+ dirs = os.listdir(args.output_dir)
1347
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1348
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1349
+ path = dirs[-1] if len(dirs) > 0 else None
1350
+
1351
+ if path is None:
1352
+ accelerator.print(
1353
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1354
+ )
1355
+ args.resume_from_checkpoint = None
1356
+ initial_global_step = 0
1357
+ else:
1358
+ accelerator.print(f"Resuming from checkpoint {path}")
1359
+ accelerator.load_state(os.path.join(args.output_dir, path))
1360
+
1361
+ if args.use_ema_adapter:
1362
+ unwrap_model(unet).set_attn_processor(ema_attn_procs)
1363
+ unwrap_model(unet).encoder_hid_proj = ema_encoder_hid_proj
1364
+ batch = next(iter(train_dataloader))
1365
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1366
+ log_validation(
1367
+ unwrap_model(unet), unwrap_model(aggregator), vae,
1368
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1369
+ noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
1370
+ args, accelerator, weight_dtype, step=0, lq_img=lq_img, gt_img=gt_img, log_local=True
1371
+ )
1372
+ if args.use_ema_adapter:
1373
+ unwrap_model(unet).set_attn_processor(orig_attn_procs)
1374
+ unwrap_model(unet).encoder_hid_proj = orig_encoder_hid_proj
1375
+ for n, p in unet.named_parameters():
1376
+ if p.requires_grad: assert p.dtype == torch.float32, n
1377
+ else: assert p.dtype == weight_dtype, n
1378
+ for n, p in aggregator.named_parameters():
1379
+ if p.requires_grad: assert p.dtype == torch.float32, n
1380
+ else: assert p.dtype == weight_dtype, n
1381
+ print("PASS")
1382
+ exit()
1383
+
1384
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1385
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1386
+ if overrode_max_train_steps:
1387
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1388
+ # Afterwards we recalculate our number of training epochs
1389
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1390
+
1391
+ # Train!
1392
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1393
+
1394
+ logger.info("***** Running training *****")
1395
+ logger.info(f" Num examples = {len(train_dataset)}")
1396
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1397
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1398
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1399
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1400
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1401
+ logger.info(f" Optimization steps per epoch = {num_update_steps_per_epoch}")
1402
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1403
+ global_step = 0
1404
+ first_epoch = 0
1405
+
1406
+ # Potentially load in the weights and states from a previous save
1407
+ if args.resume_from_checkpoint:
1408
+ if args.resume_from_checkpoint != "latest":
1409
+ path = os.path.basename(args.resume_from_checkpoint)
1410
+ else:
1411
+ # Get the most recent checkpoint
1412
+ dirs = os.listdir(args.output_dir)
1413
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1414
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1415
+ path = dirs[-1] if len(dirs) > 0 else None
1416
+
1417
+ if path is None:
1418
+ accelerator.print(
1419
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1420
+ )
1421
+ args.resume_from_checkpoint = None
1422
+ initial_global_step = 0
1423
+ else:
1424
+ accelerator.print(f"Resuming from checkpoint {path}")
1425
+ accelerator.load_state(os.path.join(args.output_dir, path))
1426
+ global_step = int(path.split("-")[1])
1427
+
1428
+ initial_global_step = global_step
1429
+ first_epoch = global_step // num_update_steps_per_epoch
1430
+ else:
1431
+ initial_global_step = 0
1432
+
1433
+ progress_bar = tqdm(
1434
+ range(0, args.max_train_steps),
1435
+ initial=initial_global_step,
1436
+ desc="Steps",
1437
+ # Only show the progress bar once on each machine.
1438
+ disable=not accelerator.is_local_main_process,
1439
+ )
1440
+
1441
+ trainable_models = [aggregator, unet]
1442
+
1443
+ if args.gradient_checkpointing:
1444
+ # TODO: add vae
1445
+ checkpoint_models = []
1446
+ else:
1447
+ checkpoint_models = []
1448
+
1449
+ image_logs = None
1450
+ tic = time.time()
1451
+ for epoch in range(first_epoch, args.num_train_epochs):
1452
+ for step, batch in enumerate(train_dataloader):
1453
+ toc = time.time()
1454
+ io_time = toc - tic
1455
+ tic = time.time()
1456
+ for model in trainable_models + checkpoint_models:
1457
+ model.train()
1458
+ with accelerator.accumulate(*trainable_models):
1459
+ loss = torch.tensor(0.0)
1460
+
1461
+ # Drop conditions.
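+ # A single uniform draw per sample assigns it to the image-drop, text-drop, or both-drop group at the configured rates (for classifier-free guidance training).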
1462
+ rand_tensor = torch.rand(batch["images"].shape[0])
1463
+ drop_image_idx = rand_tensor < args.image_drop_rate
1464
+ drop_text_idx = (rand_tensor >= args.image_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate)
1465
+ drop_both_idx = (rand_tensor >= args.image_drop_rate + args.text_drop_rate) & (rand_tensor < args.image_drop_rate + args.text_drop_rate + args.cond_drop_rate)
1466
+ drop_image_idx = drop_image_idx | drop_both_idx
1467
+ drop_text_idx = drop_text_idx | drop_both_idx
1468
+
1469
+ # Get LQ embeddings
1470
+ with torch.no_grad():
1471
+ lq_img, gt_img = deg_pipeline(batch["images"], (batch["kernel"], batch["kernel2"], batch["sinc_kernel"]))
1472
+ lq_pt = image_processor(
1473
+ images=lq_img*0.5+0.5,
1474
+ do_rescale=False, return_tensors="pt"
1475
+ ).pixel_values
1476
+
1477
+ # Move inputs to latent space.
1478
+ gt_img = gt_img.to(dtype=vae.dtype)
1479
+ lq_img = lq_img.to(dtype=vae.dtype)
1480
+ model_input = convert_to_latent(gt_img)
1481
+ lq_latent = convert_to_latent(lq_img)
1482
+ if args.pretrained_vae_model_name_or_path is None:
1483
+ model_input = model_input.to(weight_dtype)
1484
+ lq_latent = lq_latent.to(weight_dtype)
1485
+
1486
+ # Process conditions.
1487
+ image_embeds = prepare_training_image_embeds(
1488
+ image_encoder, image_processor,
1489
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
1490
+ device=accelerator.device, drop_rate=args.image_drop_rate, output_hidden_state=args.image_encoder_hidden_feature,
1491
+ idx_to_replace=drop_image_idx
1492
+ )
1493
+ prompt_embeds_input, added_conditions = compute_embeddings_fn(batch, drop_idx=drop_text_idx)
1494
+
1495
+ # Sample noise that we'll add to the latents.
1496
+ noise = torch.randn_like(model_input)
1497
+ bsz = model_input.shape[0]
1498
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device)
1499
+
1500
+ # Add noise to the model input according to the noise magnitude at each timestep
1501
+ # (this is the forward diffusion process)
1502
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1503
+ loss_weights = extract_into_tensor(importance_ratio, timesteps, noise.shape) if args.importance_sampling else None
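+ # With importance sampling, the per-timestep ratios are gathered at the sampled timesteps and broadcast to weight the diffusion loss.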
1504
+
1505
+ if args.CFG_scale > 1.0:
1506
+ # Process negative conditions.
1507
+ drop_idx = torch.ones_like(drop_image_idx)
1508
+ neg_image_embeds = prepare_training_image_embeds(
1509
+ image_encoder, image_processor,
1510
+ ip_adapter_image=lq_pt, ip_adapter_image_embeds=None,
1511
+ device=accelerator.device, drop_rate=1.0, output_hidden_state=args.image_encoder_hidden_feature,
1512
+ idx_to_replace=drop_idx
1513
+ )
1514
+ neg_prompt_embeds_input, neg_added_conditions = compute_embeddings_fn(batch, drop_idx=drop_idx)
1515
+ previewer_model_input = torch.cat([noisy_model_input] * 2)
1516
+ previewer_timesteps = torch.cat([timesteps] * 2)
1517
+ previewer_prompt_embeds = torch.cat([neg_prompt_embeds_input, prompt_embeds_input], dim=0)
1518
+ previewer_added_conditions = {}
1519
+ for k in added_conditions.keys():
1520
+ previewer_added_conditions[k] = torch.cat([neg_added_conditions[k], added_conditions[k]], dim=0)
1521
+ previewer_image_embeds = []
1522
+ for neg_image_embed, image_embed in zip(neg_image_embeds, image_embeds):
1523
+ previewer_image_embeds.append(torch.cat([neg_image_embed, image_embed], dim=0))
1524
+ previewer_added_conditions["image_embeds"] = previewer_image_embeds
1525
+ else:
1526
+ previewer_model_input = noisy_model_input
1527
+ previewer_timesteps = timesteps
1528
+ previewer_prompt_embeds = prompt_embeds_input
1529
+ previewer_added_conditions = {}
1530
+ for k in added_conditions.keys():
1531
+ previewer_added_conditions[k] = added_conditions[k]
1532
+ previewer_added_conditions["image_embeds"] = image_embeds
1533
+
1534
+ # Get LCM preview latent
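+ # When the EMA adapter is enabled, its weights are swapped in for the single-step LCM preview and the original attention processors are restored afterwards.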
1535
+ if args.use_ema_adapter:
1536
+ orig_encoder_hid_proj = unet.encoder_hid_proj
1537
+ orig_attn_procs = unet.attn_processors
1538
+ _ema_attn_procs = copy_dict(ema_attn_procs)
1539
+ unwrap_model(unet).set_attn_processor(_ema_attn_procs)
1540
+ unwrap_model(unet).encoder_hid_proj = ema_encoder_hid_proj
1541
+ unwrap_model(unet).enable_adapters()
1542
+ preview_noise = unet(
1543
+ previewer_model_input, previewer_timesteps,
1544
+ encoder_hidden_states=previewer_prompt_embeds,
1545
+ added_cond_kwargs=previewer_added_conditions,
1546
+ return_dict=False
1547
+ )[0]
1548
+ if args.CFG_scale > 1.0:
1549
+ preview_noise_uncond, preview_noise_cond = preview_noise.chunk(2)
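+ # A per-sample guidance scale is drawn uniformly from [1, CFG_scale] and broadcast over the latent dimensions.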
1550
+ cfg_scale = 1.0 + torch.rand_like(timesteps, dtype=weight_dtype) * (args.CFG_scale-1.0)
1551
+ cfg_scale = cfg_scale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
1552
+ preview_noise = preview_noise_uncond + cfg_scale * (preview_noise_cond - preview_noise_uncond)
1553
+ preview_latents = lcm_scheduler.step(
1554
+ preview_noise,
1555
+ timesteps,
1556
+ noisy_model_input,
1557
+ return_dict=False
1558
+ )[0]
1559
+ unwrap_model(unet).disable_adapters()
1560
+ if args.use_ema_adapter:
1561
+ unwrap_model(unet).set_attn_processor(orig_attn_procs)
1562
+ unwrap_model(unet).encoder_hid_proj = orig_encoder_hid_proj
1563
+ preview_error_latent = F.mse_loss(preview_latents, model_input).cpu().item()
1564
+ preview_error_noise = F.mse_loss(preview_noise, noise).cpu().item()
1565
+
1566
+ # # Add fresh noise
1567
+ # if args.noisy_encoder_input:
1568
+ # aggregator_noise = torch.randn_like(preview_latents)
1569
+ # aggregator_input = noise_scheduler.add_noise(preview_latents, aggregator_noise, timesteps)
1570
+
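+ # The aggregator consumes the LQ latent and takes the LCM preview as its ControlNet-style condition; its residuals are injected into the frozen UNet below.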
1571
+ down_block_res_samples, mid_block_res_sample = aggregator(
1572
+ lq_latent,
1573
+ timesteps,
1574
+ encoder_hidden_states=prompt_embeds_input,
1575
+ added_cond_kwargs=added_conditions,
1576
+ controlnet_cond=preview_latents,
1577
+ conditioning_scale=1.0,
1578
+ return_dict=False,
1579
+ )
1580
+
1581
+ # UNet denoise.
1582
+ added_conditions["image_embeds"] = image_embeds
1583
+ model_pred = unet(
1584
+ noisy_model_input,
1585
+ timesteps,
1586
+ encoder_hidden_states=prompt_embeds_input,
1587
+ added_cond_kwargs=added_conditions,
1588
+ down_block_additional_residuals=[
1589
+ sample.to(dtype=weight_dtype) for sample in down_block_res_samples
1590
+ ],
1591
+ mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
1592
+ return_dict=False
1593
+ )[0]
1594
+
1595
+ diffusion_loss_arguments = {
1596
+ "target": noise,
1597
+ "predict": model_pred,
1598
+ "prompt_embeddings_input": prompt_embeds_input,
1599
+ "timesteps": timesteps,
1600
+ "weights": loss_weights,
1601
+ }
1602
+
1603
+ loss_dict = dict()
1604
+ for loss_config in diffusion_losses:
1605
+ non_weighted_loss = loss_config.loss(**diffusion_loss_arguments, accelerator=accelerator)
1606
+ loss = loss + non_weighted_loss * loss_config.weight
1607
+ loss_dict[loss_config.loss.__class__.__name__] = non_weighted_loss.item()
1608
+
1609
+ accelerator.backward(loss)
1610
+ if accelerator.sync_gradients:
1611
+ accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
1612
+ optimizer.step()
1613
+ lr_scheduler.step()
1614
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1615
+
1616
+ toc = time.time()
1617
+ forward_time = toc - tic
1618
+ tic = toc
1619
+
1620
+ # Checks if the accelerator has performed an optimization step behind the scenes
1621
+ if accelerator.sync_gradients:
1622
+ progress_bar.update(1)
1623
+ global_step += 1
1624
+
1625
+ if global_step % args.ema_update_steps == 0:
1626
+ if args.use_ema_adapter:
1627
+ update_ema_model(ema_encoder_hid_proj, orig_encoder_hid_proj, ema_beta)
1628
+ update_ema_model(ema_attn_procs_list, orig_attn_procs_list, ema_beta)
1629
+
1630
+ if accelerator.is_main_process:
1631
+ if global_step % args.checkpointing_steps == 0:
1632
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1633
+ if args.checkpoints_total_limit is not None:
1634
+ checkpoints = os.listdir(args.output_dir)
1635
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1636
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1637
+
1638
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1639
+ if len(checkpoints) >= args.checkpoints_total_limit:
1640
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1641
+ removing_checkpoints = checkpoints[0:num_to_remove]
1642
+
1643
+ logger.info(
1644
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1645
+ )
1646
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1647
+
1648
+ for removing_checkpoint in removing_checkpoints:
1649
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1650
+ shutil.rmtree(removing_checkpoint)
1651
+
1652
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1653
+ accelerator.save_state(save_path)
1654
+ logger.info(f"Saved state to {save_path}")
1655
+
1656
+ if global_step % args.validation_steps == 0:
1657
+ image_logs = log_validation(
1658
+ unwrap_model(unet), unwrap_model(aggregator), vae,
1659
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1660
+ noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
1661
+ args, accelerator, weight_dtype, global_step, lq_img.detach().clone(), gt_img.detach().clone()
1662
+ )
1663
+
1664
+ logs = {}
1665
+ logs.update(loss_dict)
1666
+ logs.update(
1667
+ {"preview_error_latent": preview_error_latent, "preview_error_noise": preview_error_noise,
1668
+ "lr": lr_scheduler.get_last_lr()[0],
1669
+ "forward_time": forward_time, "io_time": io_time}
1670
+ )
1671
+ progress_bar.set_postfix(**logs)
1672
+ accelerator.log(logs, step=global_step)
1673
+ tic = time.time()
1674
+
1675
+ if global_step >= args.max_train_steps:
1676
+ break
1677
+
1678
+ # Create the pipeline using the trained modules and save it.
1679
+ accelerator.wait_for_everyone()
1680
+ if accelerator.is_main_process:
1681
+ accelerator.save_state(save_path, safe_serialization=False)
1682
+ # Run a final round of validation.
1683
+ # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
1684
+ image_logs = None
1685
+ if args.validation_image is not None:
1686
+ image_logs = log_validation(
1687
+ unwrap_model(unet), unwrap_model(aggregator), vae,
1688
+ text_encoder, text_encoder_2, tokenizer, tokenizer_2,
1689
+ noise_scheduler, lcm_scheduler, image_encoder, image_processor, deg_pipeline,
1690
+ args, accelerator, weight_dtype, global_step,
1691
+ )
1692
+
1693
+ accelerator.end_training()
1694
+
1695
+
1696
+ if __name__ == "__main__":
1697
+ args = parse_args()
1698
+ main(args)
train_stage2_aggregator.sh ADDED
@@ -0,0 +1,24 @@
1
+ # Stage 2: train aggregator
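+ # Replace the <...> placeholders with your GPU count and local paths to SDXL, the DINOv2 feature extractor, the stage-1 adapter (DCP) and the previewer LoRA before launching.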
2
+ accelerate launch --num_processes <num_of_gpus> train_stage2_aggregator.py \
3
+ --output_dir <your/output/path> \
4
+ --train_data_dir <your/data/path> \
5
+ --logging_dir <your/logging/path> \
6
+ --pretrained_model_name_or_path <your/sdxl/path> \
7
+ --feature_extractor_path <your/dinov2/path> \
8
+ --pretrained_adapter_model_path <your/dcp/path> \
9
+ --pretrained_lcm_lora_path <your/previewer_lora/path> \
10
+ --losses_config_path config_files/losses.yaml \
11
+ --data_config_path config_files/IR_dataset.yaml \
12
+ --image_drop_rate 0.0 \
13
+ --text_drop_rate 0.85 \
14
+ --cond_drop_rate 0.15 \
15
+ --save_only_adapter \
16
+ --gradient_checkpointing \
17
+ --mixed_precision fp16 \
18
+ --train_batch_size 6 \
19
+ --gradient_accumulation_steps 2 \
20
+ --learning_rate 1e-4 \
21
+ --lr_warmup_steps 1000 \
22
+ --lr_scheduler cosine \
23
+ --lr_num_cycles 1 \
24
+ --resume_from_checkpoint latest
utils/degradation_pipeline.py ADDED
@@ -0,0 +1,353 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ from torch.utils import data as data
7
+
8
+ from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels
9
+ from basicsr.data.transforms import augment
10
+ from basicsr.utils import img2tensor, DiffJPEG, USMSharp
11
+ from basicsr.utils.img_process_util import filter2D
12
+ from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
13
+ from basicsr.data.transforms import paired_random_crop
14
+
15
+ AUGMENT_OPT = {
16
+ 'use_hflip': False,
17
+ 'use_rot': False
18
+ }
19
+
20
+ KERNEL_OPT = {
21
+ 'blur_kernel_size': 21,
22
+ 'kernel_list': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
23
+ 'kernel_prob': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
24
+ 'sinc_prob': 0.1,
25
+ 'blur_sigma': [0.2, 3],
26
+ 'betag_range': [0.5, 4],
27
+ 'betap_range': [1, 2],
28
+
29
+ 'blur_kernel_size2': 21,
30
+ 'kernel_list2': ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso'],
31
+ 'kernel_prob2': [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
32
+ 'sinc_prob2': 0.1,
33
+ 'blur_sigma2': [0.2, 1.5],
34
+ 'betag_range2': [0.5, 4],
35
+ 'betap_range2': [1, 2],
36
+ 'final_sinc_prob': 0.8,
37
+ }
38
+
39
+ DEGRADE_OPT = {
40
+ 'resize_prob': [0.2, 0.7, 0.1], # up, down, keep
41
+ 'resize_range': [0.15, 1.5],
42
+ 'gaussian_noise_prob': 0.5,
43
+ 'noise_range': [1, 30],
44
+ 'poisson_scale_range': [0.05, 3],
45
+ 'gray_noise_prob': 0.4,
46
+ 'jpeg_range': [30, 95],
47
+
48
+ # the second degradation process
49
+ 'second_blur_prob': 0.8,
50
+ 'resize_prob2': [0.3, 0.4, 0.3], # up, down, keep
51
+ 'resize_range2': [0.3, 1.2],
52
+ 'gaussian_noise_prob2': 0.5,
53
+ 'noise_range2': [1, 25],
54
+ 'poisson_scale_range2': [0.05, 2.5],
55
+ 'gray_noise_prob2': 0.4,
56
+ 'jpeg_range2': [30, 95],
57
+
58
+ 'gt_size': 512,
59
+ 'no_degradation_prob': 0.01,
60
+ 'use_usm': True,
61
+ 'sf': 4,
62
+ 'random_size': False,
63
+ 'resize_lq': True
64
+ }
65
+
66
+ class RealESRGANDegradation:
67
+
68
+ def __init__(self, augment_opt=None, kernel_opt=None, degrade_opt=None, device='cuda', resolution=None):
69
+ if augment_opt is None:
70
+ augment_opt = AUGMENT_OPT
71
+ self.augment_opt = augment_opt
72
+ if kernel_opt is None:
73
+ kernel_opt = KERNEL_OPT
74
+ self.kernel_opt = kernel_opt
75
+ if degrade_opt is None:
76
+ degrade_opt = DEGRADE_OPT
77
+ self.degrade_opt = degrade_opt
78
+ if resolution is not None:
79
+ self.degrade_opt['gt_size'] = resolution
80
+ self.device = device
81
+
82
+ self.jpeger = DiffJPEG(differentiable=False).to(self.device)
83
+ self.usm_sharpener = USMSharp().to(self.device)
84
+
85
+ # blur settings for the first degradation
86
+ self.blur_kernel_size = kernel_opt['blur_kernel_size']
87
+ self.kernel_list = kernel_opt['kernel_list']
88
+ self.kernel_prob = kernel_opt['kernel_prob'] # a list for each kernel probability
89
+ self.blur_sigma = kernel_opt['blur_sigma']
90
+ self.betag_range = kernel_opt['betag_range'] # betag used in generalized Gaussian blur kernels
91
+ self.betap_range = kernel_opt['betap_range'] # betap used in plateau blur kernels
92
+ self.sinc_prob = kernel_opt['sinc_prob'] # the probability for sinc filters
93
+
94
+ # blur settings for the second degradation
95
+ self.blur_kernel_size2 = kernel_opt['blur_kernel_size2']
96
+ self.kernel_list2 = kernel_opt['kernel_list2']
97
+ self.kernel_prob2 = kernel_opt['kernel_prob2']
98
+ self.blur_sigma2 = kernel_opt['blur_sigma2']
99
+ self.betag_range2 = kernel_opt['betag_range2']
100
+ self.betap_range2 = kernel_opt['betap_range2']
101
+ self.sinc_prob2 = kernel_opt['sinc_prob2']
102
+
103
+ # a final sinc filter
104
+ self.final_sinc_prob = kernel_opt['final_sinc_prob']
105
+
106
+ self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
107
+ # TODO: kernel range is now hard-coded, should be in the config file
108
+ self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect
109
+ self.pulse_tensor[10, 10] = 1
110
+
111
+ def get_kernel(self):
112
+
113
+ # ------------------------ Generate kernels (used in the first degradation) ------------------------ #
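+ # Kernel size is drawn from {7, 9, ..., 21}; sinc kernels simulate ringing/overshoot artifacts, the others are mixed Gaussian/plateau blurs.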
114
+ kernel_size = random.choice(self.kernel_range)
115
+ if np.random.uniform() < self.kernel_opt['sinc_prob']:
116
+ # this sinc filter setting is for kernels ranging from [7, 21]
117
+ if kernel_size < 13:
118
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
119
+ else:
120
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
121
+ kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
122
+ else:
123
+ kernel = random_mixed_kernels(
124
+ self.kernel_list,
125
+ self.kernel_prob,
126
+ kernel_size,
127
+ self.blur_sigma,
128
+ self.blur_sigma, [-math.pi, math.pi],
129
+ self.betag_range,
130
+ self.betap_range,
131
+ noise_range=None)
132
+ # pad kernel
133
+ pad_size = (21 - kernel_size) // 2
134
+ kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
135
+
136
+ # ------------------------ Generate kernels (used in the second degradation) ------------------------ #
137
+ kernel_size = random.choice(self.kernel_range)
138
+ if np.random.uniform() < self.kernel_opt['sinc_prob2']:
139
+ if kernel_size < 13:
140
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
141
+ else:
142
+ omega_c = np.random.uniform(np.pi / 5, np.pi)
143
+ kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
144
+ else:
145
+ kernel2 = random_mixed_kernels(
146
+ self.kernel_list2,
147
+ self.kernel_prob2,
148
+ kernel_size,
149
+ self.blur_sigma2,
150
+ self.blur_sigma2, [-math.pi, math.pi],
151
+ self.betag_range2,
152
+ self.betap_range2,
153
+ noise_range=None)
154
+
155
+ # pad kernel
156
+ pad_size = (21 - kernel_size) // 2
157
+ kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
158
+
159
+ # ------------------------------------- the final sinc kernel ------------------------------------- #
160
+ if np.random.uniform() < self.kernel_opt['final_sinc_prob']:
161
+ kernel_size = random.choice(self.kernel_range)
162
+ omega_c = np.random.uniform(np.pi / 3, np.pi)
163
+ sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
164
+ sinc_kernel = torch.FloatTensor(sinc_kernel)
165
+ else:
166
+ sinc_kernel = self.pulse_tensor
167
+
168
+ # numpy to tensor (kernels are single-channel 2-D arrays, so no color/layout conversion is needed)
169
+ kernel = torch.FloatTensor(kernel)
170
+ kernel2 = torch.FloatTensor(kernel2)
171
+
172
+ return (kernel, kernel2, sinc_kernel)
173
+
174
+ @torch.no_grad()
175
+ def __call__(self, img_gt, kernels=None):
176
+ '''
177
+ :param img_gt: BCHW, RGB, [0, 1] float32 tensor
178
+ '''
179
+ if kernels is None:
180
+ kernel = []
181
+ kernel2 = []
182
+ sinc_kernel = []
183
+ for _ in range(img_gt.shape[0]):
184
+ k, k2, sk = self.get_kernel()
185
+ kernel.append(k)
186
+ kernel2.append(k2)
187
+ sinc_kernel.append(sk)
188
+ kernel = torch.stack(kernel)
189
+ kernel2 = torch.stack(kernel2)
190
+ sinc_kernel = torch.stack(sinc_kernel)
191
+ else:
192
+ # kernels created in dataset.
193
+ kernel, kernel2, sinc_kernel = kernels
194
+
195
+ # ----------------------- Pre-process ----------------------- #
196
+ im_gt = img_gt.to(self.device)
197
+ if self.degrade_opt['use_usm']:
198
+ im_gt = self.usm_sharpener(im_gt)
199
+ im_gt = im_gt.to(memory_format=torch.contiguous_format).float()
200
+ kernel = kernel.to(self.device)
201
+ kernel2 = kernel2.to(self.device)
202
+ sinc_kernel = sinc_kernel.to(self.device)
203
+ ori_h, ori_w = im_gt.size()[2:4]
204
+
205
+ # ----------------------- The first degradation process ----------------------- #
206
+ # blur
207
+ out = filter2D(im_gt, kernel)
208
+ # random resize
209
+ updown_type = random.choices(
210
+ ['up', 'down', 'keep'],
211
+ self.degrade_opt['resize_prob'],
212
+ )[0]
213
+ if updown_type == 'up':
214
+ scale = random.uniform(1, self.degrade_opt['resize_range'][1])
215
+ elif updown_type == 'down':
216
+ scale = random.uniform(self.degrade_opt['resize_range'][0], 1)
217
+ else:
218
+ scale = 1
219
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
220
+ out = torch.nn.functional.interpolate(out, scale_factor=scale, mode=mode)
221
+ # add noise
222
+ gray_noise_prob = self.degrade_opt['gray_noise_prob']
223
+ if random.random() < self.degrade_opt['gaussian_noise_prob']:
224
+ out = random_add_gaussian_noise_pt(
225
+ out,
226
+ sigma_range=self.degrade_opt['noise_range'],
227
+ clip=True,
228
+ rounds=False,
229
+ gray_prob=gray_noise_prob,
230
+ )
231
+ else:
232
+ out = random_add_poisson_noise_pt(
233
+ out,
234
+ scale_range=self.degrade_opt['poisson_scale_range'],
235
+ gray_prob=gray_noise_prob,
236
+ clip=True,
237
+ rounds=False)
238
+ # JPEG compression
239
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range'])
240
+ out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
241
+ out = self.jpeger(out, quality=jpeg_p)
242
+
243
+ # ----------------------- The second degradation process ----------------------- #
244
+ # blur
245
+ if random.random() < self.degrade_opt['second_blur_prob']:
246
+ out = out.contiguous()
247
+ out = filter2D(out, kernel2)
248
+ # random resize
249
+ updown_type = random.choices(
250
+ ['up', 'down', 'keep'],
251
+ self.degrade_opt['resize_prob2'],
252
+ )[0]
253
+ if updown_type == 'up':
254
+ scale = random.uniform(1, self.degrade_opt['resize_range2'][1])
255
+ elif updown_type == 'down':
256
+ scale = random.uniform(self.degrade_opt['resize_range2'][0], 1)
257
+ else:
258
+ scale = 1
259
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
260
+ out = torch.nn.functional.interpolate(
261
+ out,
262
+ size=(int(ori_h / self.degrade_opt['sf'] * scale),
263
+ int(ori_w / self.degrade_opt['sf'] * scale)),
264
+ mode=mode,
265
+ )
266
+ # add noise
267
+ gray_noise_prob = self.degrade_opt['gray_noise_prob2']
268
+ if random.random() < self.degrade_opt['gaussian_noise_prob2']:
269
+ out = random_add_gaussian_noise_pt(
270
+ out,
271
+ sigma_range=self.degrade_opt['noise_range2'],
272
+ clip=True,
273
+ rounds=False,
274
+ gray_prob=gray_noise_prob,
275
+ )
276
+ else:
277
+ out = random_add_poisson_noise_pt(
278
+ out,
279
+ scale_range=self.degrade_opt['poisson_scale_range2'],
280
+ gray_prob=gray_noise_prob,
281
+ clip=True,
282
+ rounds=False,
283
+ )
284
+
285
+ # JPEG compression + the final sinc filter
286
+ # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
287
+ # as one operation.
288
+ # We consider two orders:
289
+ # 1. [resize back + sinc filter] + JPEG compression
290
+ # 2. JPEG compression + [resize back + sinc filter]
291
+ # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
292
+ if random.random() < 0.5:
293
+ # resize back + the final sinc filter
294
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
295
+ out = torch.nn.functional.interpolate(
296
+ out,
297
+ size=(ori_h // self.degrade_opt['sf'],
298
+ ori_w // self.degrade_opt['sf']),
299
+ mode=mode,
300
+ )
301
+ out = out.contiguous()
302
+ out = filter2D(out, sinc_kernel)
303
+ # JPEG compression
304
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range2'])
305
+ out = torch.clamp(out, 0, 1)
306
+ out = self.jpeger(out, quality=jpeg_p)
307
+ else:
308
+ # JPEG compression
309
+ jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.degrade_opt['jpeg_range2'])
310
+ out = torch.clamp(out, 0, 1)
311
+ out = self.jpeger(out, quality=jpeg_p)
312
+ # resize back + the final sinc filter
313
+ mode = random.choice(['area', 'bilinear', 'bicubic'])
314
+ out = torch.nn.functional.interpolate(
315
+ out,
316
+ size=(ori_h // self.degrade_opt['sf'],
317
+ ori_w // self.degrade_opt['sf']),
318
+ mode=mode,
319
+ )
320
+ out = out.contiguous()
321
+ out = filter2D(out, sinc_kernel)
322
+
323
+ # clamp and round
324
+ im_lq = torch.clamp(out, 0, 1.0)
325
+
326
+ # random crop
327
+ gt_size = self.degrade_opt['gt_size']
328
+ im_gt, im_lq = paired_random_crop(im_gt, im_lq, gt_size, self.degrade_opt['sf'])
329
+
330
+ if self.degrade_opt['resize_lq']:
331
+ im_lq = torch.nn.functional.interpolate(
332
+ im_lq,
333
+ size=(im_gt.size(-2),
334
+ im_gt.size(-1)),
335
+ mode='bicubic',
336
+ )
337
+
338
+ if random.random() < self.degrade_opt['no_degradation_prob'] or torch.isnan(im_lq).any():
339
+ im_lq = im_gt
340
+
341
+ # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
342
+ im_lq = im_lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract
343
+ im_lq = im_lq*2 - 1.0
344
+ im_gt = im_gt*2 - 1.0
345
+
346
+ if self.degrade_opt['random_size']:
347
+ raise NotImplementedError
348
+ im_lq, im_gt = self.randn_cropinput(im_lq, im_gt)
349
+
350
+ im_lq = torch.clamp(im_lq, -1.0, 1.0)
351
+ im_gt = torch.clamp(im_gt, -1.0, 1.0)
352
+
353
+ return (im_lq, im_gt)
utils/matlab_cp2tform.py ADDED
@@ -0,0 +1,350 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Tue Jul 11 06:54:28 2017
4
+
5
+ @author: zhaoyafei
6
+ """
7
+
8
+ import numpy as np
9
+ from numpy.linalg import inv, norm, lstsq
10
+ from numpy.linalg import matrix_rank as rank
11
+
12
+ class MatlabCp2tormException(Exception):
13
+ def __str__(self):
14
+ return 'In File {}:{}'.format(
15
+ __file__, super().__str__())
16
+
17
+ def tformfwd(trans, uv):
18
+ """
19
+ Function:
20
+ ----------
21
+ apply affine transform 'trans' to uv
22
+
23
+ Parameters:
24
+ ----------
25
+ @trans: 3x3 np.array
26
+ transform matrix
27
+ @uv: Kx2 np.array
28
+ each row is a pair of coordinates (x, y)
29
+
30
+ Returns:
31
+ ----------
32
+ @xy: Kx2 np.array
33
+ each row is a pair of transformed coordinates (x, y)
34
+ """
35
+ uv = np.hstack((
36
+ uv, np.ones((uv.shape[0], 1))
37
+ ))
38
+ xy = np.dot(uv, trans)
39
+ xy = xy[:, 0:-1]
40
+ return xy
41
+
42
+
43
+ def tforminv(trans, uv):
44
+ """
45
+ Function:
46
+ ----------
47
+ apply the inverse of affine transform 'trans' to uv
48
+
49
+ Parameters:
50
+ ----------
51
+ @trans: 3x3 np.array
52
+ transform matrix
53
+ @uv: Kx2 np.array
54
+ each row is a pair of coordinates (x, y)
55
+
56
+ Returns:
57
+ ----------
58
+ @xy: Kx2 np.array
59
+ each row is a pair of inverse-transformed coordinates (x, y)
60
+ """
61
+ Tinv = inv(trans)
62
+ xy = tformfwd(Tinv, uv)
63
+ return xy
64
+
65
+
66
+ def findNonreflectiveSimilarity(uv, xy, options=None):
67
+
68
+ options = {'K': 2}
69
+
70
+ K = options['K']
71
+ M = xy.shape[0]
72
+ x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
73
+ y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
74
+ # print('--->x, y:\n', x, y
75
+
76
+ tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
77
+ tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
78
+ X = np.vstack((tmp1, tmp2))
79
+ # print('--->X.shape: ', X.shape
80
+ # print('X:\n', X
81
+
82
+ u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
83
+ v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
84
+ U = np.vstack((u, v))
85
+ # print('--->U.shape: ', U.shape
86
+ # print('U:\n', U
87
+
88
+ # We know that X * r = U
89
+ if rank(X) >= 2 * K:
90
+ r, _, _, _ = lstsq(X, U)
91
+ r = np.squeeze(r)
92
+ else:
93
+ raise Exception('cp2tform:twoUniquePointsReq')
94
+
95
+ # print('--->r:\n', r
96
+
97
+ sc = r[0]
98
+ ss = r[1]
99
+ tx = r[2]
100
+ ty = r[3]
101
+
102
+ Tinv = np.array([
103
+ [sc, -ss, 0],
104
+ [ss, sc, 0],
105
+ [tx, ty, 1]
106
+ ])
107
+
108
+ # print('--->Tinv:\n', Tinv
109
+
110
+ T = inv(Tinv)
111
+ # print('--->T:\n', T
112
+
113
+ T[:, 2] = np.array([0, 0, 1])
114
+
115
+ return T, Tinv
116
+
117
+
118
+ def findSimilarity(uv, xy, options=None):
119
+
120
+ options = {'K': 2}
121
+
122
+ # uv = np.array(uv)
123
+ # xy = np.array(xy)
124
+
125
+ # Solve for trans1
126
+ trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
127
+
128
+ # Solve for trans2
129
+
130
+ # manually reflect the xy data across the Y-axis
131
+ xyR = xy
132
+ xyR[:, 0] = -1 * xyR[:, 0]
133
+
134
+ trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
135
+
136
+ # manually reflect the tform to undo the reflection done on xyR
137
+ TreflectY = np.array([
138
+ [-1, 0, 0],
139
+ [0, 1, 0],
140
+ [0, 0, 1]
141
+ ])
142
+
143
+ trans2 = np.dot(trans2r, TreflectY)
144
+
145
+ # Figure out if trans1 or trans2 is better
146
+ xy1 = tformfwd(trans1, uv)
147
+ norm1 = norm(xy1 - xy)
148
+
149
+ xy2 = tformfwd(trans2, uv)
150
+ norm2 = norm(xy2 - xy)
151
+
152
+ if norm1 <= norm2:
153
+ return trans1, trans1_inv
154
+ else:
155
+ trans2_inv = inv(trans2)
156
+ return trans2, trans2_inv
157
+
158
+
159
+ def get_similarity_transform(src_pts, dst_pts, reflective=True):
160
+ """
161
+ Function:
162
+ ----------
163
+ Find Similarity Transform Matrix 'trans':
164
+ u = src_pts[:, 0]
165
+ v = src_pts[:, 1]
166
+ x = dst_pts[:, 0]
167
+ y = dst_pts[:, 1]
168
+ [x, y, 1] = [u, v, 1] * trans
169
+
170
+ Parameters:
171
+ ----------
172
+ @src_pts: Kx2 np.array
173
+ source points, each row is a pair of coordinates (x, y)
174
+ @dst_pts: Kx2 np.array
175
+ destination points, each row is a pair of transformed
176
+ coordinates (x, y)
177
+ @reflective: True or False
178
+ if True:
179
+ use reflective similarity transform
180
+ else:
181
+ use non-reflective similarity transform
182
+
183
+ Returns:
184
+ ----------
185
+ @trans: 3x3 np.array
186
+ transform matrix from uv to xy
187
+ trans_inv: 3x3 np.array
188
+ inverse of trans, transform matrix from xy to uv
189
+ """
190
+
191
+ if reflective:
192
+ trans, trans_inv = findSimilarity(src_pts, dst_pts)
193
+ else:
194
+ trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
195
+
196
+ return trans, trans_inv
197
+
198
+
199
+ def cvt_tform_mat_for_cv2(trans):
200
+ """
201
+ Function:
202
+ ----------
203
+ Convert Transform Matrix 'trans' into 'cv2_trans' which could be
204
+ directly used by cv2.warpAffine():
205
+ u = src_pts[:, 0]
206
+ v = src_pts[:, 1]
207
+ x = dst_pts[:, 0]
208
+ y = dst_pts[:, 1]
209
+ [x, y].T = cv_trans * [u, v, 1].T
210
+
211
+ Parameters:
212
+ ----------
213
+ @trans: 3x3 np.array
214
+ transform matrix from uv to xy
215
+
216
+ Returns:
217
+ ----------
218
+ @cv2_trans: 2x3 np.array
219
+ transform matrix from src_pts to dst_pts, could be directly used
220
+ for cv2.warpAffine()
221
+ """
222
+ cv2_trans = trans[:, 0:2].T
223
+
224
+ return cv2_trans
225
+
226
+
227
+ def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
228
+ """
229
+ Function:
230
+ ----------
231
+ Find Similarity Transform Matrix 'cv2_trans' which could be
232
+ directly used by cv2.warpAffine():
233
+ u = src_pts[:, 0]
234
+ v = src_pts[:, 1]
235
+ x = dst_pts[:, 0]
236
+ y = dst_pts[:, 1]
237
+ [x, y].T = cv_trans * [u, v, 1].T
238
+
239
+ Parameters:
240
+ ----------
241
+ @src_pts: Kx2 np.array
242
+ source points, each row is a pair of coordinates (x, y)
243
+ @dst_pts: Kx2 np.array
244
+ destination points, each row is a pair of transformed
245
+ coordinates (x, y)
246
+ reflective: True or False
247
+ if True:
248
+ use reflective similarity transform
249
+ else:
250
+ use non-reflective similarity transform
251
+
252
+ Returns:
253
+ ----------
254
+ @cv2_trans: 2x3 np.array
255
+ transform matrix from src_pts to dst_pts, could be directly used
256
+ for cv2.warpAffine()
257
+ """
258
+ trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
259
+ cv2_trans = cvt_tform_mat_for_cv2(trans)
260
+
261
+ return cv2_trans
262
+
263
+
264
+ if __name__ == '__main__':
265
+ """
266
+ u = [0, 6, -2]
267
+ v = [0, 3, 5]
268
+ x = [-1, 0, 4]
269
+ y = [-1, -10, 4]
270
+
271
+ # In Matlab, run:
272
+ #
273
+ # uv = [u'; v'];
274
+ # xy = [x'; y'];
275
+ # tform_sim=cp2tform(uv,xy,'similarity');
276
+ #
277
+ # trans = tform_sim.tdata.T
278
+ # ans =
279
+ # -0.0764 -1.6190 0
280
+ # 1.6190 -0.0764 0
281
+ # -3.2156 0.0290 1.0000
282
+ # trans_inv = tform_sim.tdata.Tinv
283
+ # ans =
284
+ #
285
+ # -0.0291 0.6163 0
286
+ # -0.6163 -0.0291 0
287
+ # -0.0756 1.9826 1.0000
288
+ # xy_m=tformfwd(tform_sim, u,v)
289
+ #
290
+ # xy_m =
291
+ #
292
+ # -3.2156 0.0290
293
+ # 1.1833 -9.9143
294
+ # 5.0323 2.8853
295
+ # uv_m=tforminv(tform_sim, x,y)
296
+ #
297
+ # uv_m =
298
+ #
299
+ # 0.5698 1.3953
300
+ # 6.0872 2.2733
301
+ # -2.6570 4.3314
302
+ """
303
+ u = [0, 6, -2]
304
+ v = [0, 3, 5]
305
+ x = [-1, 0, 4]
306
+ y = [-1, -10, 4]
307
+
308
+ uv = np.array((u, v)).T
309
+ xy = np.array((x, y)).T
310
+
311
+ print('\n--->uv:')
312
+ print(uv)
313
+ print('\n--->xy:')
314
+ print(xy)
315
+
316
+ trans, trans_inv = get_similarity_transform(uv, xy)
317
+
318
+ print('\n--->trans matrix:')
319
+ print(trans)
320
+
321
+ print('\n--->trans_inv matrix:')
322
+ print(trans_inv)
323
+
324
+ print('\n---> apply transform to uv')
325
+ print('\nxy_m = uv_augmented * trans')
326
+ uv_aug = np.hstack((
327
+ uv, np.ones((uv.shape[0], 1))
328
+ ))
329
+ xy_m = np.dot(uv_aug, trans)
330
+ print(xy_m)
331
+
332
+ print('\nxy_m = tformfwd(trans, uv)')
333
+ xy_m = tformfwd(trans, uv)
334
+ print(xy_m)
335
+
336
+ print('\n---> apply inverse transform to xy')
337
+ print('\nuv_m = xy_augmented * trans_inv')
338
+ xy_aug = np.hstack((
339
+ xy, np.ones((xy.shape[0], 1))
340
+ ))
341
+ uv_m = np.dot(xy_aug, trans_inv)
342
+ print(uv_m)
343
+
344
+ print('\nuv_m = tformfwd(trans_inv, xy)')
345
+ uv_m = tformfwd(trans_inv, xy)
346
+ print(uv_m)
347
+
348
+ uv_m = tforminv(trans, xy)
349
+ print('\nuv_m = tforminv(trans, xy)')
350
+ print(uv_m)
utils/parser.py ADDED
@@ -0,0 +1,452 @@
1
+ import argparse
2
+ import os
3
+
4
+ def parse_args(input_args=None):
5
+ parser = argparse.ArgumentParser(description="Train Consistency Encoder.")
6
+ parser.add_argument(
7
+ "--pretrained_model_name_or_path",
8
+ type=str,
9
+ default=None,
10
+ required=True,
11
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
12
+ )
13
+ parser.add_argument(
14
+ "--pretrained_vae_model_name_or_path",
15
+ type=str,
16
+ default=None,
17
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
18
+ )
19
+ parser.add_argument(
20
+ "--revision",
21
+ type=str,
22
+ default=None,
23
+ required=False,
24
+ help="Revision of pretrained model identifier from huggingface.co/models.",
25
+ )
26
+ parser.add_argument(
27
+ "--variant",
28
+ type=str,
29
+ default=None,
30
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
31
+ )
32
+
33
+ # parser.add_argument(
34
+ # "--instance_data_dir",
35
+ # type=str,
36
+ # required=True,
37
+ # help=("A folder containing the training data. "),
38
+ # )
39
+
40
+ parser.add_argument(
41
+ "--data_config_path",
42
+ type=str,
43
+ required=True,
44
+ help=("A folder containing the training data. "),
45
+ )
46
+
47
+ parser.add_argument(
48
+ "--cache_dir",
49
+ type=str,
50
+ default=None,
51
+ help="The directory where the downloaded models and datasets will be stored.",
52
+ )
53
+
54
+ parser.add_argument(
55
+ "--image_column",
56
+ type=str,
57
+ default="image",
58
+ help="The column of the dataset containing the target image. By "
59
+ "default, the standard Image Dataset maps out 'file_name' "
60
+ "to 'image'.",
61
+ )
62
+ parser.add_argument(
63
+ "--caption_column",
64
+ type=str,
65
+ default=None,
66
+ help="The column of the dataset containing the instance prompt for each image",
67
+ )
68
+
69
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
70
+
71
+ parser.add_argument(
72
+ "--instance_prompt",
73
+ type=str,
74
+ default=None,
75
+ required=True,
76
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
77
+ )
78
+
79
+ parser.add_argument(
80
+ "--validation_prompt",
81
+ type=str,
82
+ default=None,
83
+ help="A prompt that is used during validation to verify that the model is learning.",
84
+ )
85
+ parser.add_argument(
86
+ "--num_train_vis_images",
87
+ type=int,
88
+ default=2,
89
+ help="Number of images that should be generated during validation with `validation_prompt`.",
90
+ )
91
+ parser.add_argument(
92
+ "--num_validation_images",
93
+ type=int,
94
+ default=2,
95
+ help="Number of images that should be generated during validation with `validation_prompt`.",
96
+ )
97
+
98
+ parser.add_argument(
99
+ "--validation_vis_steps",
100
+ type=int,
101
+ default=500,
102
+ help=(
103
+ "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt"
104
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
105
+ ),
106
+ )
107
+
108
+ parser.add_argument(
109
+ "--train_vis_steps",
110
+ type=int,
111
+ default=500,
112
+ help=(
113
+ "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt"
114
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
115
+ ),
116
+ )
117
+
118
+ parser.add_argument(
119
+ "--vis_lcm",
120
+ type=bool,
121
+ default=True,
122
+ help=(
123
+ "Also log results of LCM inference",
124
+ ),
125
+ )
126
+
127
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="lora-dreambooth-model",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+
+    parser.add_argument("--save_only_encoder", action="store_true", help="Only save the encoder and not the full accelerator state")
+
+    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+
+    parser.add_argument("--freeze_encoder_unet", action="store_true", help="Do not train the encoder UNet")
+    parser.add_argument("--predict_word_embedding", action="store_true", help="Predict word embeddings in addition to KV features")
+    parser.add_argument("--ip_adapter_feature_extractor_path", type=str, help="Path to pre-trained feature extractor for IP-adapter")
+    parser.add_argument("--ip_adapter_model_path", type=str, help="Path to pre-trained IP-adapter.")
+    parser.add_argument("--ip_adapter_tokens", type=int, default=16, help="Number of tokens to use in the IP-adapter cross attention mechanism")
+    parser.add_argument("--optimize_adapter", action="store_true", help="Optimize IP-adapter parameters (projector + cross-attention layers)")
+    parser.add_argument("--adapter_attention_scale", type=float, default=1.0, help="Relative strength of the adapter cross attention layers")
+    parser.add_argument("--adapter_lr", type=float, help="Learning rate for the adapter parameters. Defaults to the global LR if not provided")
+
+    parser.add_argument("--noisy_encoder_input", action="store_true", help="Noise the encoder input to the same step as the decoder")
+
+    # related to CFG:
+    parser.add_argument("--adapter_drop_chance", type=float, default=0.0, help="Chance to drop the adapter condition input during training")
+    parser.add_argument("--text_drop_chance", type=float, default=0.0, help="Chance to drop the text condition during training")
+    parser.add_argument("--kv_drop_chance", type=float, default=0.0, help="Chance to drop the KV condition during training")
+
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=1024,
+        help=(
+            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+
+    parser.add_argument(
+        "--crops_coords_top_left_h",
+        type=int,
+        default=0,
+        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+    )
+
+    parser.add_argument(
+        "--crops_coords_top_left_w",
+        type=int,
+        default=0,
+        help=("Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."),
+    )
+
+    parser.add_argument(
+        "--center_crop",
+        default=False,
+        action="store_true",
+        help=(
+            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+            " cropped. The images will be resized to the resolution first before cropping."
+        ),
+    )
+
+    parser.add_argument(
+        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+    )
+
+    parser.add_argument("--num_train_epochs", type=int, default=1)
+
+    parser.add_argument(
+        "--max_train_steps",
+        type=int,
+        default=None,
+        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+    )
+
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=500,
+        help=(
+            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+            " training using `--resume_from_checkpoint`."
+        ),
+    )
+
+    parser.add_argument(
+        "--checkpoints_total_limit",
+        type=int,
+        default=5,
+        help=("Max number of checkpoints to store."),
+    )
+
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help=(
+            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+        ),
+    )
+
+    parser.add_argument("--max_timesteps_for_x0_loss", type=int, default=1001)
+
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+
+    parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+
+    parser.add_argument(
+        "--scale_lr",
+        action="store_true",
+        default=False,
+        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+    )
+
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="constant",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
+
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+
+    parser.add_argument(
+        "--lr_num_cycles",
+        type=int,
+        default=1,
+        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+    )
+
+    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+        ),
+    )
+
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+
+    parser.add_argument(
+        "--adam_epsilon",
+        type=float,
+        default=1e-08,
+        help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+    )
+
+    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+
+    parser.add_argument(
+        "--report_to",
+        type=str,
+        default="wandb",
+        help=(
+            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+            ' `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
+        ),
+    )
+
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default=None,
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or"
+            " the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+        ),
+    )
+
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
352
+
353
+ parser.add_argument(
354
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
355
+ )
356
+
357
+ parser.add_argument(
358
+ "--rank",
359
+ type=int,
360
+ default=4,
361
+ help=("The dimension of the LoRA update matrices."),
362
+ )
363
+
364
+ parser.add_argument(
365
+ "--pretrained_lcm_lora_path",
366
+ type=str,
367
+ default="latent-consistency/lcm-lora-sdxl",
368
+ help=("Path for lcm lora pretrained"),
369
+ )
370
+
371
+    parser.add_argument(
+        "--losses_config_path",
+        type=str,
+        required=True,
+        help=("A YAML file containing the losses to use and their weights."),
+    )
+
+    parser.add_argument(
+        "--lcm_every_k_steps",
+        type=int,
+        default=-1,
+        help="How often to run LCM. If -1, LCM is not run."
+    )
+
+    parser.add_argument(
+        "--lcm_batch_size",
+        type=int,
+        default=1,
+        help="Batch size for LCM."
+    )
+    parser.add_argument(
+        "--lcm_max_timestep",
+        type=int,
+        default=1000,
+        help="Max timestep to use with LCM."
+    )
+
+    parser.add_argument(
+        "--lcm_sample_scale_every_k_steps",
+        type=int,
+        default=-1,
+        help="How often to change the LCM scale. If -1, the scale is fixed at 1."
+    )
+
+    parser.add_argument(
+        "--lcm_min_scale",
+        type=float,
+        default=0.1,
+        help="When sampling the LCM scale, the minimum scale to use."
+    )
+
+    parser.add_argument(
+        "--scale_lcm_by_max_step",
+        action="store_true",
+        help="Scale the LCM LoRA alpha linearly by the maximal timestep sampled in that iteration."
+    )
+
+    parser.add_argument(
+        "--lcm_sample_full_lcm_prob",
+        type=float,
+        default=0.2,
+        help="When sampling the LCM scale, the probability of using full LCM (a scale of 1)."
+    )
+
+    parser.add_argument(
+        "--run_on_cpu",
+        action="store_true",
+        help="Whether to run on CPU."
+    )
+
+    parser.add_argument(
+        "--experiment_name",
+        type=str,
+        help=("A short description of the experiment to add to the wandb run log."),
+    )
+    parser.add_argument("--encoder_lora_rank", type=int, default=0, help="Rank of LoRA in the UNet encoder. 0 means no LoRA.")
+
+    parser.add_argument("--kvcopy_lora_rank", type=int, default=0, help="Rank of LoRA in the kvcopy modules. 0 means no LoRA.")
+
+    if input_args is not None:
+        args = parser.parse_args(input_args)
+    else:
+        args = parser.parse_args()
+
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    args.optimizer = "AdamW"
+
+    return args
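
For orientation, a minimal usage sketch of this parser, assuming the surrounding function is the usual `parse_args(input_args=None)` helper in this module; the prompt and YAML path below are illustrative placeholders, not repository defaults. Only `--instance_prompt` and `--losses_config_path` are required; everything else falls back to the defaults declared above.

```python
# Usage sketch (assumptions: the function above is `parse_args(input_args=None)` and is
# callable from this module; flag values here are placeholders, not shipped defaults).
args = parse_args(
    [
        "--instance_prompt", "photo of a TOK dog",           # required
        "--losses_config_path", "config_files/losses.yaml",  # required; illustrative path
        "--mixed_precision", "bf16",
    ]
)

assert args.learning_rate == 1e-4   # default initial LR
assert args.report_to == "wandb"    # default logging integration
assert args.optimizer == "AdamW"    # hard-coded after parsing
```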