qninhdt committed
Commit
bf63092
1 Parent(s): 9aa02cb

Upload 53 files

.gitattributes CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/dog_sit.png filter=lfs diff=lfs merge=lfs -text
+ assets/dog.png filter=lfs diff=lfs merge=lfs -text
+ assets/teaser.gif filter=lfs diff=lfs merge=lfs -text
+ assets/Teaser.png filter=lfs diff=lfs merge=lfs -text
+ multi_image/assets/realdog.gif filter=lfs diff=lfs merge=lfs -text
LICENSE.txt ADDED
@@ -0,0 +1,10 @@
+ S-Lab License 1.0
+
+ Copyright 2023 S-Lab
+ Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ 4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
+
README.md ADDED
@@ -0,0 +1,131 @@
+ <p align="center">
+ <h1 align="center">DiffMorpher: Unleashing the Capability of Diffusion Models for Image Morphing</h1>
+ <h3 align="center">CVPR 2024</h3>
+ <p align="center">
+ <a href="https://kevin-thu.github.io/homepage/"><strong>Kaiwen Zhang</strong></a>
+ &nbsp;&nbsp;
+ <a href="https://zhouyifan.net/about/"><strong>Yifan Zhou</strong></a>
+ &nbsp;&nbsp;
+ <a href="https://sheldontsui.github.io/"><strong>Xudong Xu</strong></a>
+ &nbsp;&nbsp;
+ <a href="https://xingangpan.github.io/"><strong>Xingang Pan<sup>✉</sup></strong></a>
+ &nbsp;&nbsp;
+ <a href="http://daibo.info/"><strong>Bo Dai</strong></a>
+ </p>
+ <br>
+
+ <p align="center">
+ <sup>✉</sup>Corresponding Author
+ </p>
+
+ <div align="center">
+ <img src="./assets/teaser.gif" width="500">
+ </div>
+
+ <p align="center">
+ <a href="https://arxiv.org/abs/2312.07409"><img alt='arXiv' src="https://img.shields.io/badge/arXiv-2312.07409-b31b1b.svg"></a>
+ <a href="https://kevin-thu.github.io/DiffMorpher_page/"><img alt='page' src="https://img.shields.io/badge/Project-Website-orange"></a>
+ <a href="https://twitter.com/sze68zkw"><img alt='Twitter' src="https://img.shields.io/twitter/follow/sze68zkw?label=%40KaiwenZhang"></a>
+ <a href="https://twitter.com/XingangP"><img alt='Twitter' src="https://img.shields.io/twitter/follow/XingangP?label=%40XingangPan"></a>
+ </p>
+ <br>
+ </p>
+
+ ## Web Demos
+
+ [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/KaiwenZhang/DiffMorpher)
+
+ <p align="left">
+ <a href="https://huggingface.co/spaces/Kevin-thu/DiffMorpher"><img alt="Huggingface" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DiffMorpher-orange"></a>
+ </p>
+
+ <!-- Great thanks to [OpenXLab](https://openxlab.org.cn/home) for the NVIDIA A100 GPU support! -->
+
+ ## Requirements
+ To install the requirements, run the following in your environment first:
+ ```bash
+ pip install -r requirements.txt
+ ```
+ To run the code with CUDA properly, you can comment out `torch` and `torchvision` in `requirements.txt` and install the versions that match your CUDA setup, following the instructions on [PyTorch](https://pytorch.org/get-started/locally/).
+
+ You can also download the pretrained model *Stable Diffusion v2.1-base* from [Huggingface](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) and point `model_path` to your local directory.
+
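+ For example, on a machine with CUDA 11.8 the two steps might look like this (the index URL is illustrative; pick the one matching your CUDA version from the PyTorch site):
+ ```bash
+ # requirements.txt with the torch/torchvision lines commented out
+ pip install -r requirements.txt
+ # CUDA 11.8 wheels; adjust the index URL to your CUDA version
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
+ ```
+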
+ ## Run Gradio UI
+ To start the Gradio UI of DiffMorpher, run the following in your environment:
+ ```bash
+ python app.py
+ ```
+ Then, by default, you can access the UI at [http://127.0.0.1:7860](http://127.0.0.1:7860).
+
+ ## Run the code
+ You can also run the code with the following command:
+ ```bash
+ python main.py \
+ --image_path_0 [image_path_0] --image_path_1 [image_path_1] \
+ --prompt_0 [prompt_0] --prompt_1 [prompt_1] \
+ --output_path [output_path] \
+ --use_adain --use_reschedule --save_inter
+ ```
+ The script supports the following options:
+
+ - `--image_path_0`: Path of the first image (default: "")
+ - `--prompt_0`: Prompt of the first image (default: "")
+ - `--image_path_1`: Path of the second image (default: "")
+ - `--prompt_1`: Prompt of the second image (default: "")
+ - `--model_path`: Pretrained model path (default: "stabilityai/stable-diffusion-2-1-base")
+ - `--output_path`: Path of the output directory (default: "./results")
+ - `--save_lora_dir`: Path of the output LoRA directory (default: "./lora")
+ - `--load_lora_path_0`: Path of the LoRA checkpoint of the first image (default: "")
+ - `--load_lora_path_1`: Path of the LoRA checkpoint of the second image (default: "")
+ - `--use_adain`: Use AdaIN (default: False)
+ - `--use_reschedule`: Use reschedule sampling (default: False)
+ - `--lamb`: Hyperparameter $\lambda \in [0,1]$ for self-attention replacement, where a larger $\lambda$ indicates more replacements (default: 0.6)
+ - `--fix_lora_value`: Fix the LoRA interpolation coefficient at this value (default: not fixed, i.e. LoRA interpolation)
+ - `--save_inter`: Save intermediate results (default: False)
+ - `--num_frames`: Number of frames to generate (default: 16)
+ - `--duration`: Duration of each frame in milliseconds (default: 100)
+
+ Examples:
+ ```bash
+ python main.py \
+ --image_path_0 ./assets/Trump.jpg --image_path_1 ./assets/Biden.jpg \
+ --prompt_0 "A photo of an American man" --prompt_1 "A photo of an American man" \
+ --output_path "./results/Trump_Biden" \
+ --use_adain --use_reschedule --save_inter
+ ```
+
+ ```bash
+ python main.py \
+ --image_path_0 ./assets/vangogh.jpg --image_path_1 ./assets/pearlgirl.jpg \
+ --prompt_0 "An oil painting of a man" --prompt_1 "An oil painting of a woman" \
+ --output_path "./results/vangogh_pearlgirl" \
+ --use_adain --use_reschedule --save_inter
+ ```
+
+ ```bash
+ python main.py \
+ --image_path_0 ./assets/lion.png --image_path_1 ./assets/tiger.png \
+ --prompt_0 "A photo of a lion" --prompt_1 "A photo of a tiger" \
+ --output_path "./results/lion_tiger" \
+ --use_adain --use_reschedule --save_inter
+ ```
+
+ ## MorphBench
+ To evaluate the effectiveness of our methods, we present *MorphBench*, the first benchmark dataset for assessing image morphing of general objects. You can download the dataset from [Google Drive](https://drive.google.com/file/d/1NWPzJhOgP-udP_wYbd0selRG4cu8xsu4/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1J3xE3OJdEhKyoc1QObyYaA?pwd=putk).
+
+
+ ## License
+ The code related to the DiffMorpher algorithm is licensed under [LICENSE](LICENSE.txt).
+
+ However, this project is mostly built on the open-source library [diffusers](https://github.com/huggingface/diffusers), which is under a separate license, [Apache License 2.0](https://github.com/huggingface/diffusers/blob/main/LICENSE). (Cheers to the community as well!)
+
+ ## Citation
+
+ ```bibtex
+ @article{zhang2023diffmorpher,
+ title={DiffMorpher: Unleashing the Capability of Diffusion Models for Image Morphing},
+ author={Zhang, Kaiwen and Zhou, Yifan and Xu, Xudong and Pan, Xingang and Dai, Bo},
+ journal={arXiv preprint arXiv:2312.07409},
+ year={2023}
+ }
+ ```
app.py ADDED
@@ -0,0 +1,315 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from datetime import datetime
8
+ from model import DiffMorpherPipeline
9
+ from utils.lora_utils import train_lora
10
+
11
+ LENGTH=450
12
+
13
+ def train_lora_interface(
14
+ image,
15
+ prompt,
16
+ model_path,
17
+ output_path,
18
+ lora_steps,
19
+ lora_rank,
20
+ lora_lr,
21
+ num
22
+ ):
23
+ os.makedirs(output_path, exist_ok=True)
24
+ train_lora(image, prompt, output_path, model_path,
25
+ lora_steps=lora_steps, lora_lr=lora_lr, lora_rank=lora_rank, weight_name=f"lora_{num}.ckpt", progress=gr.Progress())
26
+ return f"Train LoRA {'A' if num == 0 else 'B'} Done!"
27
+
28
+ def run_diffmorpher(
29
+ image_0,
30
+ image_1,
31
+ prompt_0,
32
+ prompt_1,
33
+ model_path,
34
+ lora_mode,
35
+ lamb,
36
+ use_adain,
37
+ use_reschedule,
38
+ num_frames,
39
+ fps,
40
+ save_inter,
41
+ load_lora_path_0,
42
+ load_lora_path_1,
43
+ output_path
44
+ ):
45
+ run_id = datetime.now().strftime("%H%M") + "_" + datetime.now().strftime("%Y%m%d")
46
+ os.makedirs(output_path, exist_ok=True)
47
+ morpher_pipeline = DiffMorpherPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cuda")
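+ # the pipeline is loaded in float32; a CUDA device is assumed to be available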
48
+ if lora_mode == "Fix LoRA A":
49
+ fix_lora = 0
50
+ elif lora_mode == "Fix LoRA B":
51
+ fix_lora = 1
52
+ else:
53
+ fix_lora = None
54
+ if not load_lora_path_0:
55
+ load_lora_path_0 = f"{output_path}/lora_0.ckpt"
56
+ if not load_lora_path_1:
57
+ load_lora_path_1 = f"{output_path}/lora_1.ckpt"
58
+ images = morpher_pipeline(
59
+ img_0=image_0,
60
+ img_1=image_1,
61
+ prompt_0=prompt_0,
62
+ prompt_1=prompt_1,
63
+ load_lora_path_0=load_lora_path_0,
64
+ load_lora_path_1=load_lora_path_1,
65
+ lamb=lamb,
66
+ use_adain=use_adain,
67
+ use_reschedule=use_reschedule,
68
+ num_frames=num_frames,
69
+ fix_lora=fix_lora,
70
+ save_intermediates=save_inter,
71
+ progress=gr.Progress()
72
+ )
73
+ video_path = f"{output_path}/{run_id}.mp4"
74
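+ # stitch the generated frames into an MP4; frames are 512x512 RGB PIL images,
+ # converted to BGR for OpenCV's VideoWriter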
+ video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (512, 512))
75
+ for i, image in enumerate(images):
76
+ # image.save(f"{output_path}/{i}.png")
77
+ video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
78
+ video.release()
79
+ cv2.destroyAllWindows()
80
+ return gr.Video(value=video_path, format="mp4", label="Output video", show_label=True, height=LENGTH, width=LENGTH, interactive=False)
81
+
82
+ def run_all(
83
+ image_0,
84
+ image_1,
85
+ prompt_0,
86
+ prompt_1,
87
+ model_path,
88
+ lora_mode,
89
+ lamb,
90
+ use_adain,
91
+ use_reschedule,
92
+ num_frames,
93
+ fps,
94
+ save_inter,
95
+ load_lora_path_0,
96
+ load_lora_path_1,
97
+ output_path,
98
+ lora_steps,
99
+ lora_rank,
100
+ lora_lr
101
+ ):
102
+ os.makedirs(output_path, exist_ok=True)
103
+ train_lora(image_0, prompt_0, output_path, model_path,
104
+ lora_steps=lora_steps, lora_lr=lora_lr, lora_rank=lora_rank, weight_name=f"lora_0.ckpt", progress=gr.Progress())
105
+ train_lora(image_1, prompt_1, output_path, model_path,
106
+ lora_steps=lora_steps, lora_lr=lora_lr, lora_rank=lora_rank, weight_name=f"lora_1.ckpt", progress=gr.Progress())
107
+ return run_diffmorpher(
108
+ image_0,
109
+ image_1,
110
+ prompt_0,
111
+ prompt_1,
112
+ model_path,
113
+ lora_mode,
114
+ lamb,
115
+ use_adain,
116
+ use_reschedule,
117
+ num_frames,
118
+ fps,
119
+ save_inter,
120
+ load_lora_path_0,
121
+ load_lora_path_1,
122
+ output_path
123
+ )
124
+
125
+ with gr.Blocks() as demo:
126
+
127
+ with gr.Row():
128
+ gr.Markdown("""
129
+ # Official Implementation of [DiffMorpher](https://kevin-thu.github.io/DiffMorpher_page/)
130
+ """)
131
+
132
+ original_image_0, original_image_1 = gr.State(Image.open("assets/Trump.jpg").convert("RGB").resize((512,512), Image.BILINEAR)), gr.State(Image.open("assets/Biden.jpg").convert("RGB").resize((512,512), Image.BILINEAR))
133
+ # key_points_0, key_points_1 = gr.State([]), gr.State([])
134
+ # to_change_points = gr.State([])
135
+
136
+ with gr.Row():
137
+ with gr.Column():
138
+ input_img_0 = gr.Image(type="numpy", label="Input image A", value="assets/Trump.jpg", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
139
+ prompt_0 = gr.Textbox(label="Prompt for image A", value="a photo of an American man", interactive=True)
140
+ with gr.Row():
141
+ train_lora_0_button = gr.Button("Train LoRA A")
142
+ train_lora_1_button = gr.Button("Train LoRA B")
143
+ # show_correspond_button = gr.Button("Show correspondence points")
144
+ with gr.Column():
145
+ input_img_1 = gr.Image(type="numpy", label="Input image B ", value="assets/Biden.jpg", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
146
+ prompt_1 = gr.Textbox(label="Prompt for image B", value="a photo of an American man", interactive=True)
147
+ with gr.Row():
148
+ clear_button = gr.Button("Clear All")
149
+ run_button = gr.Button("Run w/o LoRA training")
150
+ with gr.Column():
151
+ output_video = gr.Video(format="mp4", label="Output video", show_label=True, height=LENGTH, width=LENGTH, interactive=False)
152
+ lora_progress_bar = gr.Textbox(label="Display LoRA training progress", interactive=False)
153
+ run_all_button = gr.Button("Run!")
154
+ # with gr.Column():
155
+ # output_video = gr.Video(label="Output video", show_label=True, height=LENGTH, width=LENGTH)
156
+
157
+ with gr.Row():
158
+ gr.Markdown("""
159
+ ### Usage:
160
+ 1. Upload two images (with correspondence) and fill out the prompts.
161
+ (It's recommended to change `[Output path]` accordingly.)
162
+ 2. Click **"Run!"**
163
+
164
+ Or:
165
+ 1. Upload two images (with correspondence) and fill out the prompts.
166
+ 2. Click the **"Train LoRA A/B"** button to fit two LoRAs for two images respectively. <br> &nbsp;&nbsp;
167
+ If you have trained LoRA A or LoRA B before, you can skip this step and fill in the corresponding LoRA paths under LoRA settings. <br> &nbsp;&nbsp;
168
+ Trained LoRAs are saved to `[Output Path]/lora_0.ckpt` and `[Output Path]/lora_1.ckpt` by default.
169
+ 3. You might also change the settings below.
170
+ 4. Click **"Run w/o LoRA training"**
171
+
172
+ ### Note:
173
+ 1. To speed up the generation process, you can **reduce the number of frames** or **turn off "Use Reschedule"**.
174
+ 2. You can experiment with the influence of different prompts; using the same or closely aligned prompts tends to work better.
175
+ ### Have fun!
176
+ """)
177
+
178
+ with gr.Accordion(label="Algorithm Parameters"):
179
+ with gr.Tab("Basic Settings"):
180
+ with gr.Row():
181
+ # local_models_dir = 'local_pretrained_models'
182
+ # local_models_choice = \
183
+ # [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
184
+ model_path = gr.Text(value="stabilityai/stable-diffusion-2-1-base",
185
+ label="Diffusion Model Path", interactive=True
186
+ )
187
+ lamb = gr.Slider(value=0.6, minimum=0, maximum=1, step=0.1, label="Lambda for attention replacement", interactive=True)
188
+ lora_mode = gr.Dropdown(value="LoRA Interp",
189
+ label="LoRA Interp. or Fix LoRA",
190
+ choices=["LoRA Interp", "Fix LoRA A", "Fix LoRA B"],
191
+ interactive=True
192
+ )
193
+ use_adain = gr.Checkbox(value=True, label="Use AdaIN", interactive=True)
194
+ use_reschedule = gr.Checkbox(value=True, label="Use Reschedule", interactive=True)
195
+ with gr.Row():
196
+ num_frames = gr.Number(value=16, minimum=0, label="Number of Frames", precision=0, interactive=True)
197
+ fps = gr.Number(value=8, minimum=0, label="FPS (Frame rate)", precision=0, interactive=True)
198
+ save_inter = gr.Checkbox(value=False, label="Save Intermediate Images", interactive=True)
199
+ output_path = gr.Text(value="./results", label="Output Path", interactive=True)
200
+
201
+ with gr.Tab("LoRA Settings"):
202
+ with gr.Row():
203
+ lora_steps = gr.Number(value=200, label="LoRA training steps", precision=0, interactive=True)
204
+ lora_lr = gr.Number(value=0.0002, label="LoRA learning rate", interactive=True)
205
+ lora_rank = gr.Number(value=16, label="LoRA rank", precision=0, interactive=True)
206
+ # save_lora_dir = gr.Text(value="./lora", label="LoRA model save path", interactive=True)
207
+ load_lora_path_0 = gr.Text(value="", label="LoRA model load path for image A", interactive=True)
208
+ load_lora_path_1 = gr.Text(value="", label="LoRA model load path for image B", interactive=True)
209
+
210
+ def store_img(img):
211
+ image = Image.fromarray(img).convert("RGB").resize((512,512), Image.BILINEAR)
212
+ # resize the input to 512x512
213
+ # image = image.resize((512,512), Image.BILINEAR)
214
+ # image = np.array(image)
215
+ # when new image is uploaded, `selected_points` should be empty
216
+ return image
217
+ input_img_0.upload(
218
+ store_img,
219
+ [input_img_0],
220
+ [original_image_0]
221
+ )
222
+ input_img_1.upload(
223
+ store_img,
224
+ [input_img_1],
225
+ [original_image_1]
226
+ )
227
+
228
+ def clear(LENGTH):
229
+ return gr.Image.update(value=None, width=LENGTH, height=LENGTH), \
230
+ gr.Image.update(value=None, width=LENGTH, height=LENGTH), \
231
+ None, None, None, None
232
+ clear_button.click(
233
+ clear,
234
+ [gr.Number(value=LENGTH, visible=False, precision=0)],
235
+ [input_img_0, input_img_1, original_image_0, original_image_1, prompt_0, prompt_1]
236
+ )
237
+
238
+ train_lora_0_button.click(
239
+ train_lora_interface,
240
+ [
241
+ original_image_0,
242
+ prompt_0,
243
+ model_path,
244
+ output_path,
245
+ lora_steps,
246
+ lora_rank,
247
+ lora_lr,
248
+ gr.Number(value=0, visible=False, precision=0)
249
+ ],
250
+ [lora_progress_bar]
251
+ )
252
+
253
+ train_lora_1_button.click(
254
+ train_lora_interface,
255
+ [
256
+ original_image_1,
257
+ prompt_1,
258
+ model_path,
259
+ output_path,
260
+ lora_steps,
261
+ lora_rank,
262
+ lora_lr,
263
+ gr.Number(value=1, visible=False, precision=0)
264
+ ],
265
+ [lora_progress_bar]
266
+ )
267
+
268
+ run_button.click(
269
+ run_diffmorpher,
270
+ [
271
+ original_image_0,
272
+ original_image_1,
273
+ prompt_0,
274
+ prompt_1,
275
+ model_path,
276
+ lora_mode,
277
+ lamb,
278
+ use_adain,
279
+ use_reschedule,
280
+ num_frames,
281
+ fps,
282
+ save_inter,
283
+ load_lora_path_0,
284
+ load_lora_path_1,
285
+ output_path
286
+ ],
287
+ [output_video]
288
+ )
289
+
290
+ run_all_button.click(
291
+ run_all,
292
+ [
293
+ original_image_0,
294
+ original_image_1,
295
+ prompt_0,
296
+ prompt_1,
297
+ model_path,
298
+ lora_mode,
299
+ lamb,
300
+ use_adain,
301
+ use_reschedule,
302
+ num_frames,
303
+ fps,
304
+ save_inter,
305
+ load_lora_path_0,
306
+ load_lora_path_1,
307
+ output_path,
308
+ lora_steps,
309
+ lora_rank,
310
+ lora_lr
311
+ ],
312
+ [output_video]
313
+ )
314
+
315
+ demo.queue().launch(debug=True)
assets/Biden.jpg ADDED
assets/Feifei.jpg ADDED
assets/Musk.jpg ADDED
assets/Teaser.png ADDED

Git LFS Details

  • SHA256: 5aadae7c6c1a0a6b36a91fbf3058bf0f699cfba252dd78e6595343ec5f5a5a08
  • Pointer size: 132 Bytes
  • Size of remote file: 5.7 MB
assets/Trump.jpg ADDED
assets/cat.png ADDED
assets/dog.png ADDED

Git LFS Details

  • SHA256: 20f07d4f4e6c207426d516ddc3662572436a5a5d59cf85e47ca0d39d3a1cd252
  • Pointer size: 132 Bytes
  • Size of remote file: 1.56 MB
assets/dog_sit.png ADDED

Git LFS Details

  • SHA256: 754331bc083116d027e69e0177b3556a1cf5103c94a14d1fae292eb2a768d5b9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.37 MB
assets/drag_realgirl0.png ADDED
assets/drag_realgirl1.png ADDED
assets/drag_sculp0.png ADDED
assets/drag_sculp1.png ADDED
assets/fuji_0.jpg ADDED
assets/fuji_1.jpg ADDED
assets/house0.jpg ADDED
assets/house1.jpg ADDED
assets/jeep.jpg ADDED
assets/leo_0.jpg ADDED
assets/leo_1.jpg ADDED
assets/lion.png ADDED
assets/man_paint.png ADDED
assets/mit.jpg ADDED
assets/monalisa.jpeg ADDED
assets/obama.jpg ADDED
assets/pearlgirl.jpg ADDED
assets/rabbit.png ADDED
assets/sculp0.png ADDED
assets/sculp1.png ADDED
assets/teaser.gif ADDED

Git LFS Details

  • SHA256: 1588453d2980b0a25c64ba02429493a5b2181e0547b268386dea93b8163a8e51
  • Pointer size: 133 Bytes
  • Size of remote file: 26.7 MB
assets/thu.jpg ADDED
assets/tiger.png ADDED
assets/van.jpg ADDED
assets/vangogh.jpg ADDED
assets/vangogh_hat.png ADDED
assets/wave_paint.png ADDED
assets/wave_real.jpg ADDED
main.py ADDED
@@ -0,0 +1,98 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ from argparse import ArgumentParser
7
+ from model import DiffMorpherPipeline
8
+
9
+ parser = ArgumentParser()
10
+ parser.add_argument(
11
+ "--model_path", type=str, default="stabilityai/stable-diffusion-2-1-base",
12
+ help="Pretrained model to use (default: %(default)s)"
13
+ )
14
+ parser.add_argument(
15
+ "--image_path_0", type=str, default="",
16
+ help="Path of the first image (default: %(default)s)")
17
+ parser.add_argument(
18
+ "--prompt_0", type=str, default="",
19
+ help="Prompt of the second image (default: %(default)s)")
20
+ parser.add_argument(
21
+ "--image_path_1", type=str, default="",
22
+ help="Path of the first image (default: %(default)s)")
23
+ parser.add_argument(
24
+ "--prompt_1", type=str, default="",
25
+ help="Prompt of the second image (default: %(default)s)")
26
+ parser.add_argument(
27
+ "--output_path", type=str, default="./results",
28
+ help="Path of the output image (default: %(default)s)"
29
+ )
30
+ parser.add_argument(
31
+ "--save_lora_dir", type=str, default="./lora",
32
+ help="Path of the output lora directory (default: %(default)s)"
33
+ )
34
+ parser.add_argument(
35
+ "--load_lora_path_0", type=str, default="",
36
+ help="Path of the lora directory of the first image (default: %(default)s)"
37
+ )
38
+ parser.add_argument(
39
+ "--load_lora_path_1", type=str, default="",
40
+ help="Path of the lora directory of the second image (default: %(default)s)"
41
+ )
42
+ parser.add_argument(
43
+ "--use_adain", action="store_true",
44
+ help="Use AdaIN (default: %(default)s)"
45
+ )
46
+ parser.add_argument(
47
+ "--use_reschedule", action="store_true",
48
+ help="Use reschedule sampling (default: %(default)s)"
49
+ )
50
+ parser.add_argument(
51
+ "--lamb", type=float, default=0.6,
52
+ help="Lambda for self-attention replacement (default: %(default)s)"
53
+ )
54
+ parser.add_argument(
55
+ "--fix_lora_value", type=float, default=None,
56
+ help="Fix lora value (default: LoRA Interp., not fixed)"
57
+ )
58
+ parser.add_argument(
59
+ "--save_inter", action="store_true",
60
+ help="Save intermediate results (default: %(default)s)"
61
+ )
62
+ parser.add_argument(
63
+ "--num_frames", type=int, default=16,
64
+ help="Number of frames to generate (default: %(default)s)"
65
+ )
66
+ parser.add_argument(
67
+ "--duration", type=int, default=100,
68
+ help="Duration of each frame (default: %(default)s ms)"
69
+ )
70
+ parser.add_argument(
71
+ "--no_lora", action="store_true"
72
+ )
73
+
74
+ args = parser.parse_args()
75
+
76
+ os.makedirs(args.output_path, exist_ok=True)
77
+ pipeline = DiffMorpherPipeline.from_pretrained(
78
+ args.model_path, torch_dtype=torch.float32)
79
+ pipeline.to("cuda")
80
+ images = pipeline(
81
+ img_path_0=args.image_path_0,
82
+ img_path_1=args.image_path_1,
83
+ prompt_0=args.prompt_0,
84
+ prompt_1=args.prompt_1,
85
+ save_lora_dir=args.save_lora_dir,
86
+ load_lora_path_0=args.load_lora_path_0,
87
+ load_lora_path_1=args.load_lora_path_1,
88
+ use_adain=args.use_adain,
89
+ use_reschedule=args.use_reschedule,
90
+ lamd=args.lamb,
91
+ output_path=args.output_path,
92
+ num_frames=args.num_frames,
93
+ fix_lora=args.fix_lora_value,
94
+ save_intermediates=args.save_inter,
95
+ use_lora=not args.no_lora
96
+ )
97
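+ # save the morph sequence as a looping GIF; --duration is the per-frame
+ # display time in milliseconds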
+ images[0].save(f"{args.output_path}/output.gif", save_all=True,
98
+ append_images=images[1:], duration=args.duration, loop=0)
model.py ADDED
@@ -0,0 +1,639 @@
1
+ import os
2
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
3
+ from diffusers.models.attention_processor import AttnProcessor
4
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
5
+ from diffusers.schedulers import KarrasDiffusionSchedulers
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import tqdm
9
+ import numpy as np
10
+ import safetensors
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
14
+ from diffusers import StableDiffusionPipeline
15
+ from argparse import ArgumentParser
16
+ import inspect
17
+
18
+ from utils.model_utils import get_img, slerp, do_replace_attn
19
+ from utils.lora_utils import train_lora, load_lora
20
+ from utils.alpha_scheduler import AlphaScheduler
21
+
22
+
23
+ class StoreProcessor():
24
+ def __init__(self, original_processor, value_dict, name):
25
+ self.original_processor = original_processor
26
+ self.value_dict = value_dict
27
+ self.name = name
28
+ self.value_dict[self.name] = dict()
29
+ self.id = 0
30
+
31
+ def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
32
+ # Is self attention
33
+ if encoder_hidden_states is None:
34
+ self.value_dict[self.name][self.id] = hidden_states.detach()
35
+ self.id += 1
36
+ res = self.original_processor(attn, hidden_states, *args,
37
+ encoder_hidden_states=encoder_hidden_states,
38
+ attention_mask=attention_mask,
39
+ **kwargs)
40
+
41
+ return res
42
+
43
+
44
+ class LoadProcessor():
45
+ def __init__(self, original_processor, name, img0_dict, img1_dict, alpha, beta=0, lamd=0.6):
46
+ super().__init__()
47
+ self.original_processor = original_processor
48
+ self.name = name
49
+ self.img0_dict = img0_dict
50
+ self.img1_dict = img1_dict
51
+ self.alpha = alpha
52
+ self.beta = beta
53
+ self.lamd = lamd
54
+ self.id = 0
55
+
56
+ def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
57
+ # Is self attention
58
+ if encoder_hidden_states is None:
59
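+ # blend stored self-attention features only for the first lamd fraction of
+ # the denoising steps (the constant 50 matches the default num_inference_steps)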
+ if self.id < 50 * self.lamd:
60
+ map0 = self.img0_dict[self.name][self.id]
61
+ map1 = self.img1_dict[self.name][self.id]
62
+ cross_map = self.beta * hidden_states + \
63
+ (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
64
+ # cross_map = self.beta * hidden_states + \
65
+ # (1 - self.beta) * slerp(map0, map1, self.alpha)
66
+ # cross_map = slerp(slerp(map0, map1, self.alpha),
67
+ # hidden_states, self.beta)
68
+ # cross_map = hidden_states
69
+ # cross_map = torch.cat(
70
+ # ((1 - self.alpha) * map0, self.alpha * map1), dim=1)
71
+
72
+ res = self.original_processor(attn, hidden_states, *args,
73
+ encoder_hidden_states=cross_map,
74
+ attention_mask=attention_mask,
75
+ **kwargs)
76
+ else:
77
+ res = self.original_processor(attn, hidden_states, *args,
78
+ encoder_hidden_states=encoder_hidden_states,
79
+ attention_mask=attention_mask,
80
+ **kwargs)
81
+
82
+ self.id += 1
83
+ if self.id == len(self.img0_dict[self.name]):
85
+ self.id = 0
86
+ else:
87
+ res = self.original_processor(attn, hidden_states, *args,
88
+ encoder_hidden_states=encoder_hidden_states,
89
+ attention_mask=attention_mask,
90
+ **kwargs)
91
+
92
+ return res
93
+
94
+
95
+ class DiffMorpherPipeline(StableDiffusionPipeline):
96
+
97
+ def __init__(self,
98
+ vae: AutoencoderKL,
99
+ text_encoder: CLIPTextModel,
100
+ tokenizer: CLIPTokenizer,
101
+ unet: UNet2DConditionModel,
102
+ scheduler: KarrasDiffusionSchedulers,
103
+ safety_checker: StableDiffusionSafetyChecker,
104
+ feature_extractor: CLIPImageProcessor,
105
+ image_encoder=None,
106
+ requires_safety_checker: bool = True,
107
+ ):
108
+ sig = inspect.signature(super().__init__)
109
+ params = sig.parameters
110
+ if 'image_encoder' in params:
111
+ super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
112
+ safety_checker, feature_extractor, image_encoder, requires_safety_checker)
113
+ else:
114
+ super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
115
+ safety_checker, feature_extractor, requires_safety_checker)
116
+ self.img0_dict = dict()
117
+ self.img1_dict = dict()
118
+
119
+ def inv_step(
120
+ self,
121
+ model_output: torch.FloatTensor,
122
+ timestep: int,
123
+ x: torch.FloatTensor,
124
+ eta=0.,
125
+ verbose=False
126
+ ):
127
+ """
128
+ Inverse sampling for DDIM Inversion
129
+ """
130
+ if verbose:
131
+ print("timestep: ", timestep)
132
+ next_step = timestep
133
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps //
134
+ self.scheduler.num_inference_steps, 999)
135
+ alpha_prod_t = self.scheduler.alphas_cumprod[
136
+ timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
137
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
138
+ beta_prod_t = 1 - alpha_prod_t
139
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
140
+ pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
141
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
142
+ return x_next, pred_x0
143
+
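+ # Note on inv_step: this is the deterministic (eta = 0) DDIM update run in
+ # reverse. From x_t it estimates x_0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
+ # and re-noises to x_{t+1} = sqrt(a_{t+1}) * x_0 + sqrt(1 - a_{t+1}) * eps,
+ # where a_t is the cumulative alpha product at step t.
+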
144
+ @torch.no_grad()
145
+ def invert(
146
+ self,
147
+ image: torch.Tensor,
148
+ prompt,
149
+ num_inference_steps=50,
150
+ num_actual_inference_steps=None,
151
+ guidance_scale=1.,
152
+ eta=0.0,
153
+ **kwds):
154
+ """
155
+ invert a real image into a noise map with deterministic DDIM inversion
156
+ """
157
+ DEVICE = torch.device(
158
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
159
+ batch_size = image.shape[0]
160
+ if isinstance(prompt, list):
161
+ if batch_size == 1:
162
+ image = image.expand(len(prompt), -1, -1, -1)
163
+ elif isinstance(prompt, str):
164
+ if batch_size > 1:
165
+ prompt = [prompt] * batch_size
166
+
167
+ # text embeddings
168
+ text_input = self.tokenizer(
169
+ prompt,
170
+ padding="max_length",
171
+ max_length=77,
172
+ return_tensors="pt"
173
+ )
174
+ text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
175
+ print("input text embeddings :", text_embeddings.shape)
176
+ # define initial latents
177
+ latents = self.image2latent(image)
178
+
179
+ # unconditional embedding for classifier free guidance
180
+ if guidance_scale > 1.:
181
+ max_length = text_input.input_ids.shape[-1]
182
+ unconditional_input = self.tokenizer(
183
+ [""] * batch_size,
184
+ padding="max_length",
185
+ max_length=77,
186
+ return_tensors="pt"
187
+ )
188
+ unconditional_embeddings = self.text_encoder(
189
+ unconditional_input.input_ids.to(DEVICE))[0]
190
+ text_embeddings = torch.cat(
191
+ [unconditional_embeddings, text_embeddings], dim=0)
192
+
193
+ print("latents shape: ", latents.shape)
194
+ # iterative sampling
195
+ self.scheduler.set_timesteps(num_inference_steps)
196
+ print("Valid timesteps: ", reversed(self.scheduler.timesteps))
197
+ # print("attributes: ", self.scheduler.__dict__)
198
+ latents_list = [latents]
199
+ pred_x0_list = [latents]
200
+ for i, t in enumerate(tqdm.tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
201
+ if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
202
+ continue
203
+
204
+ if guidance_scale > 1.:
205
+ model_inputs = torch.cat([latents] * 2)
206
+ else:
207
+ model_inputs = latents
208
+
209
+ # predict the noise
210
+ noise_pred = self.unet(
211
+ model_inputs, t, encoder_hidden_states=text_embeddings).sample
212
+ if guidance_scale > 1.:
213
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
214
+ noise_pred = noise_pred_uncon + guidance_scale * \
215
+ (noise_pred_con - noise_pred_uncon)
216
+ # compute the previous noise sample x_t-1 -> x_t
217
+ latents, pred_x0 = self.inv_step(noise_pred, t, latents)
218
+ latents_list.append(latents)
219
+ pred_x0_list.append(pred_x0)
220
+
221
+ return latents
222
+
223
+ @torch.no_grad()
224
+ def ddim_inversion(self, latent, cond):
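+ # maps a clean latent back to its DDIM noise by replaying the deterministic
+ # update in reverse over the scheduler's timesteps, conditioned on `cond`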
225
+ timesteps = reversed(self.scheduler.timesteps)
226
+ with torch.autocast(device_type='cuda', dtype=torch.float32):
227
+ for i, t in enumerate(tqdm.tqdm(timesteps, desc="DDIM inversion")):
228
+ cond_batch = cond.repeat(latent.shape[0], 1, 1)
229
+
230
+ alpha_prod_t = self.scheduler.alphas_cumprod[t]
231
+ alpha_prod_t_prev = (
232
+ self.scheduler.alphas_cumprod[timesteps[i - 1]]
233
+ if i > 0 else self.scheduler.final_alpha_cumprod
234
+ )
235
+
236
+ mu = alpha_prod_t ** 0.5
237
+ mu_prev = alpha_prod_t_prev ** 0.5
238
+ sigma = (1 - alpha_prod_t) ** 0.5
239
+ sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
240
+
241
+ eps = self.unet(
242
+ latent, t, encoder_hidden_states=cond_batch).sample
243
+
244
+ pred_x0 = (latent - sigma_prev * eps) / mu_prev
245
+ latent = mu * pred_x0 + sigma * eps
246
+ # if save_latents:
247
+ # torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt'))
249
+ return latent
250
+
251
+ def step(
252
+ self,
253
+ model_output: torch.FloatTensor,
254
+ timestep: int,
255
+ x: torch.FloatTensor,
256
+ ):
257
+ """
258
+ predict the sample of the next step in the denoise process.
259
+ """
260
+ prev_timestep = timestep - \
261
+ self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
262
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
263
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[
264
+ prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
265
+ beta_prod_t = 1 - alpha_prod_t
266
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
267
+ pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
268
+ x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
269
+ return x_prev, pred_x0
270
+
271
+ @torch.no_grad()
272
+ def image2latent(self, image):
273
+ DEVICE = torch.device(
274
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
275
+ if isinstance(image, Image.Image):
276
+ image = np.array(image)
277
+ image = torch.from_numpy(image).float() / 127.5 - 1
278
+ image = image.permute(2, 0, 1).unsqueeze(0)
279
+ # input image density range [-1, 1]
280
+ latents = self.vae.encode(image.to(DEVICE))['latent_dist'].mean
281
+ latents = latents * 0.18215
282
+ return latents
283
+
284
+ @torch.no_grad()
285
+ def latent2image(self, latents, return_type='np'):
286
+ latents = 1 / 0.18215 * latents.detach()
287
+ image = self.vae.decode(latents)['sample']
288
+ if return_type == 'np':
289
+ image = (image / 2 + 0.5).clamp(0, 1)
290
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
291
+ image = (image * 255).astype(np.uint8)
292
+ elif return_type == "pt":
293
+ image = (image / 2 + 0.5).clamp(0, 1)
294
+
295
+ return image
296
+
297
+ def latent2image_grad(self, latents):
298
+ latents = 1 / 0.18215 * latents
299
+ image = self.vae.decode(latents)['sample']
300
+
301
+ return image # range [-1, 1]
302
+
303
+ @torch.no_grad()
304
+ def cal_latent(self, num_inference_steps, guidance_scale, unconditioning, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1, alpha, use_lora, fix_lora=None):
305
+ # latents = torch.cos(alpha * torch.pi / 2) * img_noise_0 + \
306
+ # torch.sin(alpha * torch.pi / 2) * img_noise_1
307
+ # latents = (1 - alpha) * img_noise_0 + alpha * img_noise_1
308
+ # latents = latents / ((1 - alpha) ** 2 + alpha ** 2)
309
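+ # spherical interpolation keeps the blended noise on (approximately) the same
+ # shell as Gaussian noise, which a plain linear blend would not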
+ latents = slerp(img_noise_0, img_noise_1, alpha, self.use_adain)
310
+ text_embeddings = (1 - alpha) * text_embeddings_0 + \
311
+ alpha * text_embeddings_1
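+ # the text embeddings, by contrast, are interpolated linearly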
312
+
313
+ self.scheduler.set_timesteps(num_inference_steps)
314
+ if use_lora:
315
+ if fix_lora is not None:
316
+ self.unet = load_lora(self.unet, lora_0, lora_1, fix_lora)
317
+ else:
318
+ self.unet = load_lora(self.unet, lora_0, lora_1, alpha)
319
+
320
+ for i, t in enumerate(tqdm.tqdm(self.scheduler.timesteps, desc=f"DDIM Sampler, alpha={alpha}")):
321
+
322
+ if guidance_scale > 1.:
323
+ model_inputs = torch.cat([latents] * 2)
324
+ else:
325
+ model_inputs = latents
326
+ if unconditioning is not None and isinstance(unconditioning, list):
327
+ _, text_embeddings = text_embeddings.chunk(2)
328
+ text_embeddings = torch.cat(
329
+ [unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
330
+ # predict the noise
331
+ noise_pred = self.unet(
332
+ model_inputs, t, encoder_hidden_states=text_embeddings).sample
333
+ if guidance_scale > 1.0:
334
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(
335
+ 2, dim=0)
336
+ noise_pred = noise_pred_uncon + guidance_scale * \
337
+ (noise_pred_con - noise_pred_uncon)
338
+ # compute the previous noise sample x_t -> x_t-1
339
+ latents = self.scheduler.step(
340
+ noise_pred, t, latents, return_dict=False)[0]
341
+ return latents
342
+
343
+ @torch.no_grad()
344
+ def get_text_embeddings(self, prompt, guidance_scale, neg_prompt, batch_size):
345
+ DEVICE = torch.device(
346
+ "cuda") if torch.cuda.is_available() else torch.device("cpu")
347
+ # text embeddings
348
+ text_input = self.tokenizer(
349
+ prompt,
350
+ padding="max_length",
351
+ max_length=77,
352
+ return_tensors="pt"
353
+ )
354
+ text_embeddings = self.text_encoder(text_input.input_ids.cuda())[0]
355
+
356
+ if guidance_scale > 1.:
357
+ if neg_prompt:
358
+ uc_text = neg_prompt
359
+ else:
360
+ uc_text = ""
361
+ unconditional_input = self.tokenizer(
362
+ [uc_text] * batch_size,
363
+ padding="max_length",
364
+ max_length=77,
365
+ return_tensors="pt"
366
+ )
367
+ unconditional_embeddings = self.text_encoder(
368
+ unconditional_input.input_ids.to(DEVICE))[0]
369
+ text_embeddings = torch.cat(
370
+ [unconditional_embeddings, text_embeddings], dim=0)
371
+
372
+ return text_embeddings
373
+
374
+ def __call__(
375
+ self,
376
+ img_0=None,
377
+ img_1=None,
378
+ img_path_0=None,
379
+ img_path_1=None,
380
+ prompt_0="",
381
+ prompt_1="",
382
+ save_lora_dir="./lora",
383
+ load_lora_path_0=None,
384
+ load_lora_path_1=None,
385
+ lora_steps=200,
386
+ lora_lr=2e-4,
387
+ lora_rank=16,
388
+ batch_size=1,
389
+ height=512,
390
+ width=512,
391
+ num_inference_steps=50,
392
+ num_actual_inference_steps=None,
393
+ guidance_scale=1,
394
+ attn_beta=0,
395
+ lamd=0.6,
396
+ use_lora=True,
397
+ use_adain=True,
398
+ use_reschedule=True,
399
+ output_path="./results",
400
+ num_frames=50,
401
+ fix_lora=None,
402
+ progress=tqdm,
403
+ unconditioning=None,
404
+ neg_prompt=None,
405
+ save_intermediates=False,
406
+ **kwds):
407
+
408
+ # if isinstance(prompt, list):
409
+ # batch_size = len(prompt)
410
+ # elif isinstance(prompt, str):
411
+ # if batch_size > 1:
412
+ # prompt = [prompt] * batch_size
413
+ self.scheduler.set_timesteps(num_inference_steps)
414
+ self.use_lora = use_lora
415
+ self.use_adain = use_adain
416
+ self.use_reschedule = use_reschedule
417
+ self.output_path = output_path
418
+
419
+ if img_0 is None:
420
+ img_0 = Image.open(img_path_0).convert("RGB")
421
+ # else:
422
+ # img_0 = Image.fromarray(img_0).convert("RGB")
423
+
424
+ if img_1 is None:
425
+ img_1 = Image.open(img_path_1).convert("RGB")
426
+ # else:
427
+ # img_1 = Image.fromarray(img_1).convert("RGB")
428
+
429
+ if self.use_lora:
430
+ print("Loading lora...")
431
+ if not load_lora_path_0:
432
+
433
+ weight_name = f"{output_path.split('/')[-1]}_lora_0.ckpt"
434
+ load_lora_path_0 = save_lora_dir + "/" + weight_name
435
+ if not os.path.exists(load_lora_path_0):
436
+ train_lora(img_0, prompt_0, save_lora_dir, None, self.tokenizer, self.text_encoder,
437
+ self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
438
+ print(f"Load from {load_lora_path_0}.")
439
+ if load_lora_path_0.endswith(".safetensors"):
440
+ lora_0 = safetensors.torch.load_file(
441
+ load_lora_path_0, device="cpu")
442
+ else:
443
+ lora_0 = torch.load(load_lora_path_0, map_location="cpu")
444
+
445
+ if not load_lora_path_1:
446
+ weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt"
447
+ load_lora_path_1 = save_lora_dir + "/" + weight_name
448
+ if not os.path.exists(load_lora_path_1):
449
+ train_lora(img_1, prompt_1, save_lora_dir, None, self.tokenizer, self.text_encoder,
450
+ self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
451
+ print(f"Load from {load_lora_path_1}.")
452
+ if load_lora_path_1.endswith(".safetensors"):
453
+ lora_1 = safetensors.torch.load_file(
454
+ load_lora_path_1, device="cpu")
455
+ else:
456
+ lora_1 = torch.load(load_lora_path_1, map_location="cpu")
457
+ else:
458
+ lora_0 = lora_1 = None
459
+
460
+ text_embeddings_0 = self.get_text_embeddings(
461
+ prompt_0, guidance_scale, neg_prompt, batch_size)
462
+ text_embeddings_1 = self.get_text_embeddings(
463
+ prompt_1, guidance_scale, neg_prompt, batch_size)
464
+ img_0 = get_img(img_0)
465
+ img_1 = get_img(img_1)
466
+ if self.use_lora:
467
+ self.unet = load_lora(self.unet, lora_0, lora_1, 0)
468
+ img_noise_0 = self.ddim_inversion(
469
+ self.image2latent(img_0), text_embeddings_0)
470
+ if self.use_lora:
471
+ self.unet = load_lora(self.unet, lora_0, lora_1, 1)
472
+ img_noise_1 = self.ddim_inversion(
473
+ self.image2latent(img_1), text_embeddings_1)
474
+
475
+ print("latents shape: ", img_noise_0.shape)
476
+
477
+ original_processor = list(self.unet.attn_processors.values())[0]
478
+
479
+ def morph(alpha_list, progress, desc):
480
+ images = []
481
+ if attn_beta is not None:
482
+ if self.use_lora:
483
+ self.unet = load_lora(
484
+ self.unet, lora_0, lora_1, 0 if fix_lora is None else fix_lora)
485
+
486
+ attn_processor_dict = {}
487
+ for k in self.unet.attn_processors.keys():
488
+ if do_replace_attn(k):
489
+ if self.use_lora:
490
+ attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
491
+ self.img0_dict, k)
492
+ else:
493
+ attn_processor_dict[k] = StoreProcessor(original_processor,
494
+ self.img0_dict, k)
495
+ else:
496
+ attn_processor_dict[k] = self.unet.attn_processors[k]
497
+ self.unet.set_attn_processor(attn_processor_dict)
498
+
499
+ latents = self.cal_latent(
500
+ num_inference_steps,
501
+ guidance_scale,
502
+ unconditioning,
503
+ img_noise_0,
504
+ img_noise_1,
505
+ text_embeddings_0,
506
+ text_embeddings_1,
507
+ lora_0,
508
+ lora_1,
509
+ alpha_list[0],
510
+ False,
511
+ fix_lora
512
+ )
513
+ first_image = self.latent2image(latents)
514
+ first_image = Image.fromarray(first_image)
515
+ if save_intermediates:
516
+ first_image.save(f"{self.output_path}/{0:02d}.png")
517
+
518
+ if self.use_lora:
519
+ self.unet = load_lora(
520
+ self.unet, lora_0, lora_1, 1 if fix_lora is None else fix_lora)
521
+ attn_processor_dict = {}
522
+ for k in self.unet.attn_processors.keys():
523
+ if do_replace_attn(k):
524
+ if self.use_lora:
525
+ attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
526
+ self.img1_dict, k)
527
+ else:
528
+ attn_processor_dict[k] = StoreProcessor(original_processor,
529
+ self.img1_dict, k)
530
+ else:
531
+ attn_processor_dict[k] = self.unet.attn_processors[k]
532
+
533
+ self.unet.set_attn_processor(attn_processor_dict)
534
+
535
+ latents = self.cal_latent(
536
+ num_inference_steps,
537
+ guidance_scale,
538
+ unconditioning,
539
+ img_noise_0,
540
+ img_noise_1,
541
+ text_embeddings_0,
542
+ text_embeddings_1,
543
+ lora_0,
544
+ lora_1,
545
+ alpha_list[-1],
546
+ False,
547
+ fix_lora
548
+ )
549
+ last_image = self.latent2image(latents)
550
+ last_image = Image.fromarray(last_image)
551
+ if save_intermediates:
552
+ last_image.save(
553
+ f"{self.output_path}/{num_frames - 1:02d}.png")
554
+
555
+ for i in progress.tqdm(range(1, num_frames - 1), desc=desc):
556
+ alpha = alpha_list[i]
557
+ if self.use_lora:
558
+ self.unet = load_lora(
559
+ self.unet, lora_0, lora_1, alpha if fix_lora is None else fix_lora)
560
+
561
+ attn_processor_dict = {}
562
+ for k in self.unet.attn_processors.keys():
563
+ if do_replace_attn(k):
564
+ if self.use_lora:
565
+ attn_processor_dict[k] = LoadProcessor(
566
+ self.unet.attn_processors[k], k, self.img0_dict, self.img1_dict, alpha, attn_beta, lamd)
567
+ else:
568
+ attn_processor_dict[k] = LoadProcessor(
569
+ original_processor, k, self.img0_dict, self.img1_dict, alpha, attn_beta, lamd)
570
+ else:
571
+ attn_processor_dict[k] = self.unet.attn_processors[k]
572
+
573
+ self.unet.set_attn_processor(attn_processor_dict)
574
+
575
+ latents = self.cal_latent(
576
+ num_inference_steps,
577
+ guidance_scale,
578
+ unconditioning,
579
+ img_noise_0,
580
+ img_noise_1,
581
+ text_embeddings_0,
582
+ text_embeddings_1,
583
+ lora_0,
584
+ lora_1,
585
+ alpha_list[i],
586
+ False,
587
+ fix_lora
588
+ )
589
+ image = self.latent2image(latents)
590
+ image = Image.fromarray(image)
591
+ if save_intermediates:
592
+ image.save(f"{self.output_path}/{i:02d}.png")
593
+ images.append(image)
594
+
595
+ images = [first_image] + images + [last_image]
596
+
597
+ else:
598
+ for k, alpha in enumerate(alpha_list):
599
+
600
+ latents = self.cal_latent(
601
+ num_inference_steps,
602
+ guidance_scale,
603
+ unconditioning,
604
+ img_noise_0,
605
+ img_noise_1,
606
+ text_embeddings_0,
607
+ text_embeddings_1,
608
+ lora_0,
609
+ lora_1,
610
+ alpha_list[k],
611
+ self.use_lora,
612
+ fix_lora
613
+ )
614
+ image = self.latent2image(latents)
615
+ image = Image.fromarray(image)
616
+ if save_intermediates:
617
+ image.save(f"{self.output_path}/{k:02d}.png")
618
+ images.append(image)
619
+
620
+ return images
621
+
622
+ with torch.no_grad():
623
+ if self.use_reschedule:
624
+ alpha_scheduler = AlphaScheduler()
625
+ alpha_list = list(torch.linspace(0, 1, num_frames))
626
+ images_pt = morph(alpha_list, progress, "Sampling...")
627
+ images_pt = [transforms.ToTensor()(img).unsqueeze(0)
628
+ for img in images_pt]
629
+ alpha_scheduler.from_imgs(images_pt)
630
+ alpha_list = alpha_scheduler.get_list()
631
+ print(alpha_list)
632
+ images = morph(alpha_list, progress, "Reschedule..."
633
+ )
634
+ else:
635
+ alpha_list = list(torch.linspace(0, 1, num_frames))
636
+ print(alpha_list)
637
+ images = morph(alpha_list, progress, "Sampling...")
638
+
639
+ return images
multi_image/README.md ADDED
@@ -0,0 +1,37 @@
+ ## Update
+
+ Add support for multi-image input. You can now produce morphing output across more than two images.
+
+ ## Run the code
+
+ You can run the code with the following command:
+
+ ```bash
+ python main.py \
+ --image_paths [image_path_0] ... [image_path_n] \
+ --prompts [prompt_0] ... [prompt_n] \
+ --output_path [output_path] \
+ --use_adain --use_reschedule --save_inter
+ ```
+
+ This modification adds support for the following options:
+
+ - `--image_paths`: Paths of the input images
+ - `--prompts`: Prompts of the images
+ - `--load_lora_paths`: Paths of the LoRA checkpoints of the images
+
+ ## Example
+
+ Run the code:
+ ```bash
+ python main.py \
+ --image_paths ./assets/realdog0.jpg ./assets/realdog1.jpg ./assets/realdog2.jpg \
+ --prompts "A photo of a dog" "A photo of a dog" "A photo of a dog" \
+ --output_path "./results/dog" \
+ --use_adain --use_reschedule --save_inter
+ ```
+
+ Output:
+ <div align="center">
+ <img src="assets/realdog.gif" width="50%" height="50%">
+ </div>
multi_image/assets/realdog.gif ADDED

Git LFS Details

  • SHA256: a70a7919940a458f77ee8715f09d4ff5a944a19ea70d643e17ab2f555034ed6a
  • Pointer size: 132 Bytes
  • Size of remote file: 5.21 MB
multi_image/assets/realdog0.jpg ADDED
multi_image/assets/realdog1.jpg ADDED
multi_image/assets/realdog2.jpg ADDED
multi_image/main.py ADDED
@@ -0,0 +1,107 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ from argparse import ArgumentParser
7
+ from model import DiffMorpherPipeline
8
+
9
+ parser = ArgumentParser()
10
+ parser.add_argument(
11
+ "--model_path", type=str, default="stabilityai/stable-diffusion-2-1-base",
12
+ help="Pretrained model to use (default: %(default)s)"
13
+ )
14
+ parser.add_argument(
15
+ "--image_path_0", type=str, default="",
16
+ help="Path of the first image (default: %(default)s)")
17
+ parser.add_argument(
18
+ "--prompt_0", type=str, default="",
19
+ help="Prompt of the second image (default: %(default)s)")
20
+ parser.add_argument(
21
+ "--image_path_1", type=str, default="",
22
+ help="Path of the first image (default: %(default)s)")
23
+ parser.add_argument(
24
+ "--prompt_1", type=str, default="",
25
+ help="Prompt of the second image (default: %(default)s)")
26
+ parser.add_argument(
27
+ "--load_lora_path_0", type=str, default="",
28
+ help="Path of the lora directory of the first image (default: %(default)s)"
29
+ )
30
+ parser.add_argument(
31
+ "--load_lora_path_1", type=str, default="",
32
+ help="Path of the lora directory of the second image (default: %(default)s)"
33
+ )
34
+ parser.add_argument(
35
+ "--image_paths", type=str, nargs='*', default=[],
36
+ help="Path of the first image (default: %(default)s)")
37
+ parser.add_argument(
38
+ "--prompts", type=str, nargs='*', default=[],
39
+ help="Prompt of the second image (default: %(default)s)")
40
+ parser.add_argument(
41
+ "--output_path", type=str, default="./results",
42
+ help="Path of the output image (default: %(default)s)"
43
+ )
44
+ parser.add_argument(
45
+ "--save_lora_dir", type=str, default="./lora",
46
+ help="Path of the output lora directory (default: %(default)s)"
47
+ )
48
+ parser.add_argument(
49
+ "--load_lora_paths", type=str, nargs='*', default=[],
50
+ help="Path of the lora directory of the first image (default: %(default)s)"
51
+ )
52
+ parser.add_argument(
53
+ "--use_adain", action="store_true",
54
+ help="Use AdaIN (default: %(default)s)"
55
+ )
56
+ parser.add_argument(
57
+ "--use_reschedule", action="store_true",
58
+ help="Use reschedule sampling (default: %(default)s)"
59
+ )
60
+ parser.add_argument(
61
+ "--lamb", type=float, default=0.6,
62
+ help="Lambda for self-attention replacement (default: %(default)s)"
63
+ )
64
+ parser.add_argument(
65
+ "--fix_lora_value", type=float, default=None,
66
+ help="Fix lora value (default: LoRA Interp., not fixed)"
67
+ )
68
+ parser.add_argument(
69
+ "--save_inter", action="store_true",
70
+ help="Save intermediate results (default: %(default)s)"
71
+ )
72
+ parser.add_argument(
73
+ "--num_frames", type=int, default=16,
74
+ help="Number of frames to generate (default: %(default)s)"
75
+ )
76
+ parser.add_argument(
77
+ "--duration", type=int, default=100,
78
+ help="Duration of each frame (default: %(default)s ms)"
79
+ )
80
+
81
+ args = parser.parse_args()
82
+
83
+ os.makedirs(args.output_path, exist_ok=True)
84
+ pipeline = DiffMorpherPipeline.from_pretrained(
85
+ args.model_path, torch_dtype=torch.float32)
86
+ pipeline.to("cuda")
87
+ images = pipeline(
88
+ img_path_0=args.image_path_0,
89
+ img_path_1=args.image_path_1,
90
+ prompt_0=args.prompt_0,
91
+ prompt_1=args.prompt_1,
92
+ load_lora_path_0=args.load_lora_path_0,
93
+ load_lora_path_1=args.load_lora_path_1,
94
+ img_paths=args.image_paths,
95
+ prompts=args.prompts,
96
+ save_lora_dir=args.save_lora_dir,
97
+ load_lora_paths=args.load_lora_paths,
98
+ use_adain=args.use_adain,
99
+ use_reschedule=args.use_reschedule,
100
+ lamb=args.lamb,
101
+ output_path=args.output_path,
102
+ num_frames=args.num_frames,
103
+ fix_lora=args.fix_lora_value,
104
+ save_intermediates=args.save_inter,
105
+ )
106
+ images[0].save(f"{args.output_path}/output.gif", save_all=True,
107
+ append_images=images[1:], duration=args.duration, loop=0)
multi_image/model.py ADDED
@@ -0,0 +1,699 @@
+ import os
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.models.attention_processor import AttnProcessor
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+ from diffusers.schedulers import KarrasDiffusionSchedulers
+ import torch
+ import torch.nn.functional as F
+ import tqdm
+ import numpy as np
+ import safetensors
+ from PIL import Image
+ from torchvision import transforms
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+ from diffusers import StableDiffusionPipeline
+ from argparse import ArgumentParser
+ 
+ 
+ from utils.model_utils import get_img, slerp, do_replace_attn
+ from utils.lora_utils import train_lora, load_lora
+ from utils.alpha_scheduler import AlphaScheduler
+
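+ # Caches the input hidden states of every self-attention call (keyed by layer
+ # name and call order) while an endpoint image is denoised, so those features
+ # can later be blended into the intermediate frames.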
+ class StoreProcessor():
+     def __init__(self, original_processor, value_dict, name):
+         self.original_processor = original_processor
+         self.value_dict = value_dict
+         self.name = name
+         self.value_dict[self.name] = dict()
+         self.id = 0
+ 
+     def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
+         # Is self attention
+         if encoder_hidden_states is None:
+             self.value_dict[self.name][self.id] = hidden_states.detach()
+             self.id += 1
+         res = self.original_processor(attn, hidden_states, *args,
+                                       encoder_hidden_states=encoder_hidden_states,
+                                       attention_mask=attention_mask,
+                                       **kwargs)
+ 
+         return res
+ 
+
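+ # Wraps a LoRA attention processor: during the early denoising steps the
+ # self-attention context is replaced by an interpolation of the feature maps
+ # cached from the two endpoint images (weighted by alpha, optionally mixed
+ # with the current frame's own features via beta).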
+ class LoadProcessor():
+     def __init__(self, original_processor, name, img0_dict, img1_dict, alpha, beta=0, lamd=0.6):
+         self.original_processor = original_processor
+         self.name = name
+         self.img0_dict = img0_dict
+         self.img1_dict = img1_dict
+         self.alpha = alpha
+         self.beta = beta
+         self.lamd = lamd
+         self.id = 0
+ 
+     def parent_call(
+         self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
+     ):
+         residual = hidden_states
+ 
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states)
+ 
+         input_ndim = hidden_states.ndim
+ 
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(
+                 batch_size, channel, height * width).transpose(1, 2)
+ 
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(
+             attention_mask, sequence_length, batch_size)
+ 
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(
+                 hidden_states.transpose(1, 2)).transpose(1, 2)
+ 
+         query = attn.to_q(hidden_states) + scale * \
+             self.original_processor.to_q_lora(hidden_states)
+         query = attn.head_to_batch_dim(query)
+ 
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(
+                 encoder_hidden_states)
+ 
+         key = attn.to_k(encoder_hidden_states) + scale * \
+             self.original_processor.to_k_lora(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states) + scale * \
+             self.original_processor.to_v_lora(encoder_hidden_states)
+ 
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+ 
+         attention_probs = attn.get_attention_scores(
+             query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+ 
+         # linear proj
+         hidden_states = attn.to_out[0](
+             hidden_states) + scale * self.original_processor.to_out_lora(hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+ 
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(
+                 -1, -2).reshape(batch_size, channel, height, width)
+ 
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+ 
+         hidden_states = hidden_states / attn.rescale_output_factor
+ 
+         return hidden_states
+ 
+     def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
+         # Is self attention
+         if encoder_hidden_states is None:
+             # blend only during the first `lamd` fraction of the denoising
+             # steps (the step count 50 is hardcoded here)
+             if self.id < 50 * self.lamd:
+                 map0 = self.img0_dict[self.name][self.id]
+                 map1 = self.img1_dict[self.name][self.id]
+                 cross_map = self.beta * hidden_states + \
+                     (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
+                 res = self.parent_call(attn, hidden_states, *args,
+                                        encoder_hidden_states=cross_map,
+                                        attention_mask=attention_mask,
+                                        **kwargs)
+             else:
+                 res = self.original_processor(attn, hidden_states, *args,
+                                               encoder_hidden_states=encoder_hidden_states,
+                                               attention_mask=attention_mask,
+                                               **kwargs)
+ 
+             self.id += 1
+             if self.id == len(self.img0_dict[self.name]):
+                 self.id = 0
+         else:
+             res = self.original_processor(attn, hidden_states, *args,
+                                           encoder_hidden_states=encoder_hidden_states,
+                                           attention_mask=attention_mask,
+                                           **kwargs)
+ 
+         return res
+ 
+
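+ # Multi-image DiffMorpher pipeline: fits a LoRA per input image, DDIM-inverts
+ # each image into a noise latent, then renders in-between frames for every
+ # consecutive pair by jointly interpolating noise latents (slerp), text
+ # embeddings, LoRA weights, and self-attention features.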
+ class DiffMorpherPipeline(StableDiffusionPipeline):
+ 
+     def __init__(self,
+                  vae: AutoencoderKL,
+                  text_encoder: CLIPTextModel,
+                  tokenizer: CLIPTokenizer,
+                  unet: UNet2DConditionModel,
+                  scheduler: KarrasDiffusionSchedulers,
+                  safety_checker: StableDiffusionSafetyChecker,
+                  feature_extractor: CLIPImageProcessor,
+                  requires_safety_checker: bool = True,
+                  ):
+         super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
+                          safety_checker, feature_extractor, requires_safety_checker)
+         self.img0_dict = dict()
+         self.img1_dict = dict()
+ 
+     def inv_step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: int,
+         x: torch.FloatTensor,
+         eta=0.,
+         verbose=False
+     ):
+         """
+         Inverse sampling step for DDIM inversion: x_t -> x_{t+1}
+         """
+         if verbose:
+             print("timestep: ", timestep)
+         next_step = timestep
+         timestep = min(timestep - self.scheduler.config.num_train_timesteps //
+                        self.scheduler.num_inference_steps, 999)
+         alpha_prod_t = self.scheduler.alphas_cumprod[
+             timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+         alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+         beta_prod_t = 1 - alpha_prod_t
+         pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+         pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
+         x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+         return x_next, pred_x0
+
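+     # Full DDIM inversion of a real image: encode it to a latent, then
+     # repeatedly apply inv_step with the UNet's noise prediction (optionally
+     # under classifier-free guidance) to obtain a noise map that reconstructs
+     # the image when denoised.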
+     @torch.no_grad()
+     def invert(
+         self,
+         image: torch.Tensor,
+         prompt,
+         num_inference_steps=50,
+         num_actual_inference_steps=None,
+         guidance_scale=1.,
+         eta=0.0,
+         **kwds):
+         """
+         Invert a real image into a noise map with deterministic DDIM inversion
+         """
+         DEVICE = torch.device(
+             "cuda") if torch.cuda.is_available() else torch.device("cpu")
+         batch_size = image.shape[0]
+         if isinstance(prompt, list):
+             if batch_size == 1:
+                 image = image.expand(len(prompt), -1, -1, -1)
+         elif isinstance(prompt, str):
+             if batch_size > 1:
+                 prompt = [prompt] * batch_size
+ 
+         # text embeddings
+         text_input = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=77,
+             return_tensors="pt"
+         )
+         text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
+         print("input text embeddings:", text_embeddings.shape)
+         # define initial latents
+         latents = self.image2latent(image)
+ 
+         # unconditional embedding for classifier-free guidance
+         if guidance_scale > 1.:
+             max_length = text_input.input_ids.shape[-1]
+             unconditional_input = self.tokenizer(
+                 [""] * batch_size,
+                 padding="max_length",
+                 max_length=77,
+                 return_tensors="pt"
+             )
+             unconditional_embeddings = self.text_encoder(
+                 unconditional_input.input_ids.to(DEVICE))[0]
+             text_embeddings = torch.cat(
+                 [unconditional_embeddings, text_embeddings], dim=0)
+ 
+         print("latents shape: ", latents.shape)
+         # iterative sampling
+         self.scheduler.set_timesteps(num_inference_steps)
+         print("Valid timesteps: ", reversed(self.scheduler.timesteps))
+         latents_list = [latents]
+         pred_x0_list = [latents]
+         for i, t in enumerate(tqdm.tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
+             if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
+                 continue
+ 
+             if guidance_scale > 1.:
+                 model_inputs = torch.cat([latents] * 2)
+             else:
+                 model_inputs = latents
+ 
+             # predict the noise
+             noise_pred = self.unet(
+                 model_inputs, t, encoder_hidden_states=text_embeddings).sample
+             if guidance_scale > 1.:
+                 noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
+                 noise_pred = noise_pred_uncon + guidance_scale * \
+                     (noise_pred_con - noise_pred_uncon)
+             # compute the next noise sample x_t -> x_t+1
+             latents, pred_x0 = self.inv_step(noise_pred, t, latents)
+             latents_list.append(latents)
+             pred_x0_list.append(pred_x0)
+ 
+         return latents
+
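+     # Lighter-weight inversion loop, and the one actually used by __call__
+     # (single conditioning batch, no guidance); invert() above is an
+     # alternative that also supports classifier-free guidance.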
+     @torch.no_grad()
+     def ddim_inversion(self, latent, cond):
+         timesteps = reversed(self.scheduler.timesteps)
+         with torch.autocast(device_type='cuda', dtype=torch.float32):
+             for i, t in enumerate(tqdm.tqdm(timesteps, desc="DDIM inversion")):
+                 cond_batch = cond.repeat(latent.shape[0], 1, 1)
+ 
+                 alpha_prod_t = self.scheduler.alphas_cumprod[t]
+                 alpha_prod_t_prev = (
+                     self.scheduler.alphas_cumprod[timesteps[i - 1]]
+                     if i > 0 else self.scheduler.final_alpha_cumprod
+                 )
+ 
+                 mu = alpha_prod_t ** 0.5
+                 mu_prev = alpha_prod_t_prev ** 0.5
+                 sigma = (1 - alpha_prod_t) ** 0.5
+                 sigma_prev = (1 - alpha_prod_t_prev) ** 0.5
+ 
+                 eps = self.unet(
+                     latent, t, encoder_hidden_states=cond_batch).sample
+ 
+                 pred_x0 = (latent - sigma_prev * eps) / mu_prev
+                 latent = mu * pred_x0 + sigma * eps
+         return latent
+ 
+     def step(
+         self,
+         model_output: torch.FloatTensor,
+         timestep: int,
+         x: torch.FloatTensor,
+     ):
+         """
+         Predict the sample at the previous timestep (one DDIM denoising step).
+         """
+         prev_timestep = timestep - \
+             self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+         alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.scheduler.alphas_cumprod[
+             prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
+         beta_prod_t = 1 - alpha_prod_t
+         pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+         pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
+         x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
+         return x_prev, pred_x0
+
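+     # VAE helpers. 0.18215 is the standard Stable Diffusion latent scaling
+     # factor.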
+     @torch.no_grad()
+     def image2latent(self, image):
+         DEVICE = torch.device(
+             "cuda") if torch.cuda.is_available() else torch.device("cpu")
+         if isinstance(image, Image.Image):
+             image = np.array(image)
+             image = torch.from_numpy(image).float() / 127.5 - 1
+             image = image.permute(2, 0, 1).unsqueeze(0)
+         # input image value range [-1, 1]
+         latents = self.vae.encode(image.to(DEVICE))['latent_dist'].mean
+         latents = latents * 0.18215
+         return latents
+ 
+     @torch.no_grad()
+     def latent2image(self, latents, return_type='np'):
+         latents = 1 / 0.18215 * latents.detach()
+         image = self.vae.decode(latents)['sample']
+         if return_type == 'np':
+             image = (image / 2 + 0.5).clamp(0, 1)
+             image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
+             image = (image * 255).astype(np.uint8)
+         elif return_type == "pt":
+             image = (image / 2 + 0.5).clamp(0, 1)
+ 
+         return image
+ 
+     def latent2image_grad(self, latents):
+         latents = 1 / 0.18215 * latents
+         image = self.vae.decode(latents)['sample']
+ 
+         return image  # range [-1, 1]
+
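+     # Renders one frame at interpolation weight alpha: slerp the two inverted
+     # noise maps (optionally AdaIN-aligned), lerp the text embeddings,
+     # optionally interpolate the two LoRA weight sets, then run the DDIM
+     # sampling loop.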
+     @torch.no_grad()
+     def cal_latent(self, num_inference_steps, guidance_scale, unconditioning, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1, alpha, use_lora, fix_lora=None):
+         latents = slerp(img_noise_0, img_noise_1, alpha, self.use_adain)
+         text_embeddings = (1 - alpha) * text_embeddings_0 + \
+             alpha * text_embeddings_1
+ 
+         self.scheduler.set_timesteps(num_inference_steps)
+         if use_lora:
+             if fix_lora is not None:
+                 self.unet = load_lora(self.unet, lora_0, lora_1, fix_lora)
+             else:
+                 self.unet = load_lora(self.unet, lora_0, lora_1, alpha)
+ 
+         for i, t in enumerate(tqdm.tqdm(self.scheduler.timesteps, desc=f"DDIM Sampler, alpha={alpha}")):
+ 
+             if guidance_scale > 1.:
+                 model_inputs = torch.cat([latents] * 2)
+             else:
+                 model_inputs = latents
+             if unconditioning is not None and isinstance(unconditioning, list):
+                 _, text_embeddings = text_embeddings.chunk(2)
+                 text_embeddings = torch.cat(
+                     [unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
+             # predict the noise
+             noise_pred = self.unet(
+                 model_inputs, t, encoder_hidden_states=text_embeddings).sample
+             if guidance_scale > 1.0:
+                 noise_pred_uncon, noise_pred_con = noise_pred.chunk(
+                     2, dim=0)
+                 noise_pred = noise_pred_uncon + guidance_scale * \
+                     (noise_pred_con - noise_pred_uncon)
+             # compute the previous noise sample x_t -> x_t-1
+             latents = self.scheduler.step(
+                 noise_pred, t, latents, return_dict=False)[0]
+         return latents
+
+     @torch.no_grad()
+     def get_text_embeddings(self, prompt, guidance_scale, neg_prompt, batch_size):
+         DEVICE = torch.device(
+             "cuda") if torch.cuda.is_available() else torch.device("cpu")
+         # text embeddings
+         text_input = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=77,
+             return_tensors="pt"
+         )
+         text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
+ 
+         if guidance_scale > 1.:
+             if neg_prompt:
+                 uc_text = neg_prompt
+             else:
+                 uc_text = ""
+             unconditional_input = self.tokenizer(
+                 [uc_text] * batch_size,
+                 padding="max_length",
+                 max_length=77,
+                 return_tensors="pt"
+             )
+             unconditional_embeddings = self.text_encoder(
+                 unconditional_input.input_ids.to(DEVICE))[0]
+             text_embeddings = torch.cat(
+                 [unconditional_embeddings, text_embeddings], dim=0)
+ 
+         return text_embeddings
+
+     def __call__(
+         self,
+         img_0=None,
+         img_1=None,
+         img_path_0=None,
+         img_path_1=None,
+         prompt_0="",
+         prompt_1="",
+         imgs=None,
+         img_paths=None,
+         prompts=None,
+         save_lora_dir="./lora",
+         load_lora_path_0=None,
+         load_lora_path_1=None,
+         load_lora_paths=None,
+         lora_steps=200,
+         lora_lr=2e-4,
+         lora_rank=16,
+         batch_size=1,
+         height=512,
+         width=512,
+         num_inference_steps=50,
+         num_actual_inference_steps=None,
+         guidance_scale=1,
+         attn_beta=0,
+         lamd=0.6,
+         use_lora=True,
+         use_adain=True,
+         use_reschedule=True,
+         output_path="./results",
+         num_frames=50,
+         fix_lora=None,
+         progress=tqdm,
+         unconditioning=None,
+         neg_prompt=None,
+         save_intermediates=False,
+         **kwds):
+ 
+         self.scheduler.set_timesteps(num_inference_steps)
+         self.use_lora = use_lora
+         self.use_adain = use_adain
+         self.use_reschedule = use_reschedule
+         self.output_path = output_path
+ 
+         imgs = [Image.open(img_path).convert("RGB") for img_path in img_paths]
+         assert len(prompts) == len(imgs)
+ 
+         if load_lora_paths is None:
+             load_lora_paths = []
+         loras = [None] * len(imgs)  # placeholders so the pairwise zip below also works when use_lora is False
+         if self.use_lora:
+             loras = []
+             print("Loading lora...")
+             for i, (img, prompt) in enumerate(zip(imgs, prompts)):
+                 if len(load_lora_paths) == i:
+                     weight_name = f"{output_path.split('/')[-1]}_lora_{i}.ckpt"
+                     load_lora_paths.append(save_lora_dir + "/" + weight_name)
+                     if not os.path.exists(load_lora_paths[i]):
+                         train_lora(img, prompt, save_lora_dir, None, self.tokenizer, self.text_encoder,
+                                    self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
+                 print(f"Load from {load_lora_paths[i]}.")
+                 if load_lora_paths[i].endswith(".safetensors"):
+                     loras.append(safetensors.torch.load_file(
+                         load_lora_paths[i], device="cpu"))
+                 else:
+                     loras.append(torch.load(load_lora_paths[i], map_location="cpu"))
+
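+         # morph() renders a whole frame sequence for one image pair: the two
+         # endpoints are denoised first with StoreProcessor attached (caching
+         # their self-attention features), then every intermediate alpha is
+         # denoised with LoadProcessor blending those cached features.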
+         def morph(alpha_list, progress, desc, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1):
+             images = []
+             if attn_beta is not None:
+                 self.unet = load_lora(self.unet, lora_0, lora_1, 0 if fix_lora is None else fix_lora)
+                 attn_processor_dict = {}
+                 for k in self.unet.attn_processors.keys():
+                     if do_replace_attn(k):
+                         attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
+                                                                 self.img0_dict, k)
+                     else:
+                         attn_processor_dict[k] = self.unet.attn_processors[k]
+                 self.unet.set_attn_processor(attn_processor_dict)
+ 
+                 latents = self.cal_latent(
+                     num_inference_steps,
+                     guidance_scale,
+                     unconditioning,
+                     img_noise_0,
+                     img_noise_1,
+                     text_embeddings_0,
+                     text_embeddings_1,
+                     lora_0,
+                     lora_1,
+                     alpha_list[0],
+                     False,
+                     fix_lora
+                 )
+                 first_image = self.latent2image(latents)
+                 first_image = Image.fromarray(first_image)
+ 
+                 self.unet = load_lora(self.unet, lora_0, lora_1, 1 if fix_lora is None else fix_lora)
+                 attn_processor_dict = {}
+                 for k in self.unet.attn_processors.keys():
+                     if do_replace_attn(k):
+                         attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k],
+                                                                 self.img1_dict, k)
+                     else:
+                         attn_processor_dict[k] = self.unet.attn_processors[k]
+                 self.unet.set_attn_processor(attn_processor_dict)
+ 
+                 latents = self.cal_latent(
+                     num_inference_steps,
+                     guidance_scale,
+                     unconditioning,
+                     img_noise_0,
+                     img_noise_1,
+                     text_embeddings_0,
+                     text_embeddings_1,
+                     lora_0,
+                     lora_1,
+                     alpha_list[-1],
+                     False,
+                     fix_lora
+                 )
+                 last_image = self.latent2image(latents)
+                 last_image = Image.fromarray(last_image)
+ 
+                 for i in progress.tqdm(range(1, num_frames - 1), desc=desc):
+                     alpha = alpha_list[i]
+                     self.unet = load_lora(self.unet, lora_0, lora_1, alpha if fix_lora is None else fix_lora)
+                     attn_processor_dict = {}
+                     for k in self.unet.attn_processors.keys():
+                         if do_replace_attn(k):
+                             attn_processor_dict[k] = LoadProcessor(
+                                 self.unet.attn_processors[k], k, self.img0_dict, self.img1_dict, alpha, attn_beta, lamd)
+                         else:
+                             attn_processor_dict[k] = self.unet.attn_processors[k]
+                     self.unet.set_attn_processor(attn_processor_dict)
+ 
+                     latents = self.cal_latent(
+                         num_inference_steps,
+                         guidance_scale,
+                         unconditioning,
+                         img_noise_0,
+                         img_noise_1,
+                         text_embeddings_0,
+                         text_embeddings_1,
+                         lora_0,
+                         lora_1,
+                         alpha_list[i],
+                         False,
+                         fix_lora
+                     )
+                     image = self.latent2image(latents)
+                     image = Image.fromarray(image)
+                     images.append(image)
+ 
+                 images = [first_image] + images + [last_image]
+ 
+             else:
+                 for k, alpha in enumerate(alpha_list):
+                     latents = self.cal_latent(
+                         num_inference_steps,
+                         guidance_scale,
+                         unconditioning,
+                         img_noise_0,
+                         img_noise_1,
+                         text_embeddings_0,
+                         text_embeddings_1,
+                         lora_0,
+                         lora_1,
+                         alpha_list[k],
+                         self.use_lora,
+                         fix_lora
+                     )
+                     image = self.latent2image(latents)
+                     image = Image.fromarray(image)
+                     images.append(image)
+ 
+             return images
+
+         images = []
+ 
+         for img_0, img_1, prompt_0, prompt_1, lora_0, lora_1 in zip(imgs[:-1], imgs[1:], prompts[:-1], prompts[1:], loras[:-1], loras[1:]):
+             text_embeddings_0 = self.get_text_embeddings(
+                 prompt_0, guidance_scale, neg_prompt, batch_size)
+             text_embeddings_1 = self.get_text_embeddings(
+                 prompt_1, guidance_scale, neg_prompt, batch_size)
+             img_0 = get_img(img_0)
+             img_1 = get_img(img_1)
+             if self.use_lora:
+                 self.unet = load_lora(self.unet, lora_0, lora_1, 0)
+             img_noise_0 = self.ddim_inversion(
+                 self.image2latent(img_0), text_embeddings_0)
+             if self.use_lora:
+                 self.unet = load_lora(self.unet, lora_0, lora_1, 1)
+             img_noise_1 = self.ddim_inversion(
+                 self.image2latent(img_1), text_embeddings_1)
+ 
+             print("latents shape: ", img_noise_0.shape)
+
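+             # Optional rescheduling: render once with uniform alphas, let
+             # AlphaScheduler pick a new alpha list from those frames so the
+             # visual change between consecutive frames is more even, then
+             # render again with the rescheduled alphas.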
+             with torch.no_grad():
+                 if self.use_reschedule:
+                     alpha_scheduler = AlphaScheduler()
+                     alpha_list = list(torch.linspace(0, 1, num_frames))
+                     images_pt = morph(alpha_list, progress, "Sampling...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1)
+                     images_pt = [transforms.ToTensor()(img).unsqueeze(0)
+                                  for img in images_pt]
+                     alpha_scheduler.from_imgs(images_pt)
+                     alpha_list = alpha_scheduler.get_list()
+                     print(alpha_list)
+                     images_ = morph(alpha_list, progress, "Reschedule...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1)
+                 else:
+                     alpha_list = list(torch.linspace(0, 1, num_frames))
+                     print(alpha_list)
+                     images_ = morph(alpha_list, progress, "Sampling...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1)
+ 
+             if len(images) == 0:
+                 images = images_
+             else:
+                 images += images_[1:]
+ 
+         if save_intermediates:
+             for i, image in enumerate(images):
+                 image.save(f"{self.output_path}/{i:02d}.png")
+ 
+         return images
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ accelerate==0.23.0
+ diffusers==0.17.1
+ einops==0.7.0
+ gradio==4.7.1
+ numpy==1.26.1
+ opencv_python==4.5.5.64
+ packaging==23.2
+ Pillow==10.1.0
+ safetensors==0.4.0
+ tqdm==4.65.0
+ transformers==4.34.1
+ torch
+ torchvision
+ lpips
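
For reference, a minimal usage sketch of the pipeline added above, mirroring the call in multi_image/main.py (the checkpoint id, image paths, and prompts are illustrative assumptions, not part of this commit; run from inside multi_image/):

import os
import torch
from model import DiffMorpherPipeline

# Any Stable Diffusion 1.x checkpoint should work; SD 1.5 is assumed here.
pipeline = DiffMorpherPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32)
pipeline.to("cuda")

os.makedirs("./lora", exist_ok=True)     # per-image LoRAs are trained and cached here on first run
os.makedirs("./results", exist_ok=True)
images = pipeline(
    img_paths=["../assets/dog.png", "../assets/dog_sit.png"],      # hypothetical inputs
    prompts=["a photo of a dog", "a photo of a sitting dog"],      # hypothetical prompts
    save_lora_dir="./lora",
    load_lora_paths=[],
    use_adain=True,
    use_reschedule=False,
    output_path="./results",
    num_frames=16,
)
images[0].save("./results/output.gif", save_all=True,
               append_images=images[1:], duration=100, loop=0)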