Upload 53 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +5 -0
- LICENSE.txt +10 -0
- README.md +131 -0
- app.py +315 -0
- assets/Biden.jpg +0 -0
- assets/Feifei.jpg +0 -0
- assets/Musk.jpg +0 -0
- assets/Teaser.png +3 -0
- assets/Trump.jpg +0 -0
- assets/cat.png +0 -0
- assets/dog.png +3 -0
- assets/dog_sit.png +3 -0
- assets/drag_realgirl0.png +0 -0
- assets/drag_realgirl1.png +0 -0
- assets/drag_sculp0.png +0 -0
- assets/drag_sculp1.png +0 -0
- assets/fuji_0.jpg +0 -0
- assets/fuji_1.jpg +0 -0
- assets/house0.jpg +0 -0
- assets/house1.jpg +0 -0
- assets/jeep.jpg +0 -0
- assets/leo_0.jpg +0 -0
- assets/leo_1.jpg +0 -0
- assets/lion.png +0 -0
- assets/man_paint.png +0 -0
- assets/mit.jpg +0 -0
- assets/monalisa.jpeg +0 -0
- assets/obama.jpg +0 -0
- assets/pearlgirl.jpg +0 -0
- assets/rabbit.png +0 -0
- assets/sculp0.png +0 -0
- assets/sculp1.png +0 -0
- assets/teaser.gif +3 -0
- assets/thu.jpg +0 -0
- assets/tiger.png +0 -0
- assets/van.jpg +0 -0
- assets/vangogh.jpg +0 -0
- assets/vangogh_hat.png +0 -0
- assets/wave_paint.png +0 -0
- assets/wave_real.jpg +0 -0
- main.py +98 -0
- model.py +639 -0
- multi_image/README.md +37 -0
- multi_image/assets/realdog.gif +3 -0
- multi_image/assets/realdog0.jpg +0 -0
- multi_image/assets/realdog1.jpg +0 -0
- multi_image/assets/realdog2.jpg +0 -0
- multi_image/main.py +107 -0
- multi_image/model.py +699 -0
- requirements.txt +14 -0
.gitattributes
CHANGED
@@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/dog_sit.png filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/dog.png filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/teaser.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/Teaser.png filter=lfs diff=lfs merge=lfs -text
|
40 |
+
multi_image/assets/realdog.gif filter=lfs diff=lfs merge=lfs -text
|
LICENSE.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
S-Lab License 1.0
|
2 |
+
|
3 |
+
Copyright 2023 S-Lab
|
4 |
+
Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
5 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
6 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
7 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
8 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
9 |
+
4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
|
10 |
+
|
README.md
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center">
|
2 |
+
<h1 align="center">DiffMorpher: Unleashing the Capability of Diffusion Models for Image Morphing</h1>
|
3 |
+
<h3 align="center">CVPR 2024</h3>
|
4 |
+
<p align="center">
|
5 |
+
<a href="https://kevin-thu.github.io/homepage/"><strong>Kaiwen Zhang</strong></a>
|
6 |
+
|
7 |
+
<a href="https://zhouyifan.net/about/"><strong>Yifan Zhou</strong></a>
|
8 |
+
|
9 |
+
<a href="https://sheldontsui.github.io/"><strong>Xudong Xu</strong></a>
|
10 |
+
|
11 |
+
<a href="https://xingangpan.github.io/"><strong>Xingang Pan<sup>✉</sup></strong></a>
|
12 |
+
|
13 |
+
<a href="http://daibo.info/"><strong>Bo Dai</strong></a>
|
14 |
+
</p>
|
15 |
+
<br>
|
16 |
+
|
17 |
+
<p align="center">
|
18 |
+
<sup>✉</sup>Corresponding Author
|
19 |
+
</p>
|
20 |
+
|
21 |
+
<div align="center">
|
22 |
+
<img src="./assets/teaser.gif" width="500">
|
23 |
+
</div>
|
24 |
+
|
25 |
+
<p align="center">
|
26 |
+
<a href="https://arxiv.org/abs/2312.07409"><img alt='arXiv' src="https://img.shields.io/badge/arXiv-2312.07409-b31b1b.svg"></a>
|
27 |
+
<a href="https://kevin-thu.github.io/DiffMorpher_page/"><img alt='page' src="https://img.shields.io/badge/Project-Website-orange"></a>
|
28 |
+
<a href="https://twitter.com/sze68zkw"><img alt='Twitter' src="https://img.shields.io/twitter/follow/sze68zkw?label=%40KaiwenZhang"></a>
|
29 |
+
<a href="https://twitter.com/XingangP"><img alt='Twitter' src="https://img.shields.io/twitter/follow/XingangP?label=%40XingangPan"></a>
|
30 |
+
</p>
|
31 |
+
<br>
|
32 |
+
</p>
|
33 |
+
|
34 |
+
## Web Demos
|
35 |
+
|
36 |
+
[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/KaiwenZhang/DiffMorpher)
|
37 |
+
|
38 |
+
<p align="left">
|
39 |
+
<a href="https://huggingface.co/spaces/Kevin-thu/DiffMorpher"><img alt="Huggingface" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DiffMorpher-orange"></a>
|
40 |
+
</p>
|
41 |
+
|
42 |
+
<!-- Great thanks to [OpenXLab](https://openxlab.org.cn/home) for the NVIDIA A100 GPU support! -->
|
43 |
+
|
44 |
+
## Requirements
|
45 |
+
To install the requirements, run the following in your environment first:
|
46 |
+
```bash
|
47 |
+
pip install -r requirements.txt
|
48 |
+
```
|
49 |
+
To run the code with CUDA properly, you can comment out `torch` and `torchvision` in `requirements.txt`, and install the appropriate versions of `torch` and `torchvision` according to the instructions on [PyTorch](https://pytorch.org/get-started/locally/).
|
50 |
+
|
51 |
+
You can also download the pretrained model *Stable Diffusion v2.1-base* from [Huggingface](https://huggingface.co/stabilityai/stable-diffusion-2-1-base), and specify the `model_path` to your local directory.
|
52 |
+
|
53 |
+
## Run Gradio UI
|
54 |
+
To start the Gradio UI of DiffMorpher, run the following in your environment:
|
55 |
+
```bash
|
56 |
+
python app.py
|
57 |
+
```
|
58 |
+
Then, by default, you can access the UI at [http://127.0.0.1:7860](http://127.0.0.1:7860).
|
59 |
+
|
60 |
+
## Run the code
|
61 |
+
You can also run the code with the following command:
|
62 |
+
```bash
|
63 |
+
python main.py \
|
64 |
+
--image_path_0 [image_path_0] --image_path_1 [image_path_1] \
|
65 |
+
--prompt_0 [prompt_0] --prompt_1 [prompt_1] \
|
66 |
+
--output_path [output_path] \
|
67 |
+
--use_adain --use_reschedule --save_inter
|
68 |
+
```
|
69 |
+
The script also supports the following options:
|
70 |
+
|
71 |
+
- `--image_path_0`: Path of the first image (default: "")
|
72 |
+
- `--prompt_0`: Prompt of the first image (default: "")
|
73 |
+
- `--image_path_1`: Path of the second image (default: "")
|
74 |
+
- `--prompt_1`: Prompt of the second image (default: "")
|
75 |
+
- `--model_path`: Pretrained model path (default: "stabilityai/stable-diffusion-2-1-base")
|
76 |
+
- `--output_path`: Path of the output directory (default: "./results")
|
77 |
+
- `--save_lora_dir`: Path of the output lora directory (default: "./lora")
|
78 |
+
- `--load_lora_path_0`: Path of the lora directory of the first image (default: "")
|
79 |
+
- `--load_lora_path_1`: Path of the lora directory of the second image (default: "")
|
80 |
+
- `--use_adain`: Use AdaIN (default: False)
|
81 |
+
- `--use_reschedule`: Use reschedule sampling (default: False)
|
82 |
+
- `--lamb`: Hyperparameter $\lambda \in [0,1]$ for self-attention replacement, where a larger $\lambda$ indicates more replacements (default: 0.6)
|
83 |
+
- `--fix_lora_value`: Fix lora value (default: LoRA Interpolation, not fixed)
|
84 |
+
- `--save_inter`: Save intermediate results (default: False)
|
85 |
+
- `--num_frames`: Number of frames to generate (default: 16)
|
86 |
+
- `--duration`: Duration of each frame in milliseconds (default: 100)
|
87 |
+
|
88 |
+
Examples:
|
89 |
+
```bash
|
90 |
+
python main.py \
|
91 |
+
--image_path_0 ./assets/Trump.jpg --image_path_1 ./assets/Biden.jpg \
|
92 |
+
--prompt_0 "A photo of an American man" --prompt_1 "A photo of an American man" \
|
93 |
+
--output_path "./results/Trump_Biden" \
|
94 |
+
--use_adain --use_reschedule --save_inter
|
95 |
+
```
|
96 |
+
|
97 |
+
```bash
|
98 |
+
python main.py \
|
99 |
+
--image_path_0 ./assets/vangogh.jpg --image_path_1 ./assets/pearlgirl.jpg \
|
100 |
+
--prompt_0 "An oil painting of a man" --prompt_1 "An oil painting of a woman" \
|
101 |
+
--output_path "./results/vangogh_pearlgirl" \
|
102 |
+
--use_adain --use_reschedule --save_inter
|
103 |
+
```
|
104 |
+
|
105 |
+
```bash
|
106 |
+
python main.py \
|
107 |
+
--image_path_0 ./assets/lion.png --image_path_1 ./assets/tiger.png \
|
108 |
+
--prompt_0 "A photo of a lion" --prompt_1 "A photo of a tiger" \
|
109 |
+
--output_path "./results/lion_tiger" \
|
110 |
+
--use_adain --use_reschedule --save_inter
|
111 |
+
```
|
112 |
+
|
113 |
+
## MorphBench
|
114 |
+
To evaluate the effectiveness of our methods, we present *MorphBench*, the first benchmark dataset for assessing image morphing of general objects. You can download the dataset from [Google Drive](https://drive.google.com/file/d/1NWPzJhOgP-udP_wYbd0selRG4cu8xsu4/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1J3xE3OJdEhKyoc1QObyYaA?pwd=putk).
|
115 |
+
|
116 |
+
|
117 |
+
## License
|
118 |
+
The code related to the DiffMorpher algorithm is licensed under [LICENSE](LICENSE.txt).
|
119 |
+
|
120 |
+
However, this project is mostly built on the open-source library [diffusers](https://github.com/huggingface/diffusers), which is under separate license terms: [Apache License 2.0](https://github.com/huggingface/diffusers/blob/main/LICENSE). (Cheers to the community as well!)
|
121 |
+
|
122 |
+
## Citation
|
123 |
+
|
124 |
+
```bibtex
|
125 |
+
@article{zhang2023diffmorpher,
|
126 |
+
title={DiffMorpher: Unleashing the Capability of Diffusion Models for Image Morphing},
|
127 |
+
author={Zhang, Kaiwen and Zhou, Yifan and Xu, Xudong and Pan, Xingang and Dai, Bo},
|
128 |
+
journal={arXiv preprint arXiv:2312.07409},
|
129 |
+
year={2023}
|
130 |
+
}
|
131 |
+
```
|
app.py
ADDED
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
import gradio as gr
|
6 |
+
from PIL import Image
|
7 |
+
from datetime import datetime
|
8 |
+
from model import DiffMorpherPipeline
|
9 |
+
from utils.lora_utils import train_lora
|
10 |
+
|
11 |
+
LENGTH=450
|
12 |
+
|
13 |
+
def train_lora_interface(
    image,
    prompt,
    model_path,
    output_path,
    lora_steps,
    lora_rank,
    lora_lr,
    num
):
    """Gradio callback that fits a single LoRA and returns a status string.

    The output directory is created if needed, then ``train_lora`` is asked
    to write its weights under ``output_path`` as ``lora_{num}.ckpt``.

    Args:
        image: Image the LoRA is fitted to (forwarded to ``train_lora``).
        prompt: Text prompt paired with ``image``.
        model_path: Diffusion model path forwarded to ``train_lora``.
        output_path: Directory that receives the trained weights.
        lora_steps: Number of training steps.
        lora_rank: LoRA rank.
        lora_lr: LoRA learning rate.
        num: Index of the LoRA; ``0`` is reported as "A", anything else as "B".

    Returns:
        A short completion message naming LoRA "A" or "B".
    """
    os.makedirs(output_path, exist_ok=True)

    weight_name = f"lora_{num}.ckpt"
    train_lora(
        image,
        prompt,
        output_path,
        model_path,
        lora_steps=lora_steps,
        lora_lr=lora_lr,
        lora_rank=lora_rank,
        weight_name=weight_name,
        progress=gr.Progress(),
    )

    label = "A" if num == 0 else "B"
    return f"Train LoRA {label} Done!"
|
27 |
+
|
28 |
+
def run_diffmorpher(
    image_0,
    image_1,
    prompt_0,
    prompt_1,
    model_path,
    lora_mode,
    lamb,
    use_adain,
    use_reschedule,
    num_frames,
    fps,
    save_inter,
    load_lora_path_0,
    load_lora_path_1,
    output_path
):
    """Run the morphing pipeline and encode the frames as an MP4 video.

    Args:
        image_0: First input image (passed to the pipeline as ``img_0``).
        image_1: Second input image (passed to the pipeline as ``img_1``).
        prompt_0: Prompt for the first image.
        prompt_1: Prompt for the second image.
        model_path: Pretrained diffusion model path or hub id.
        lora_mode: ``"Fix LoRA A"`` / ``"Fix LoRA B"`` select a fixed LoRA
            (``fix_lora`` 0 / 1); any other value means ``fix_lora=None``.
        lamb: Lambda for attention replacement.
        use_adain: Whether to enable AdaIN in the pipeline.
        use_reschedule: Whether to enable reschedule sampling.
        num_frames: Number of frames to generate.
        fps: Frame rate of the written video.
        save_inter: Forwarded as ``save_intermediates``.
        load_lora_path_0: LoRA checkpoint for image A; when empty,
            ``{output_path}/lora_0.ckpt`` is used.
        load_lora_path_1: LoRA checkpoint for image B; when empty,
            ``{output_path}/lora_1.ckpt`` is used.
        output_path: Directory for the video (and default LoRA checkpoints).

    Returns:
        A ``gr.Video`` component whose value is the path of the written MP4.
    """
    # One timestamp only: two separate now() calls could straddle a
    # minute/day rollover and yield an inconsistent id.
    run_id = datetime.now().strftime("%H%M_%Y%m%d")
    os.makedirs(output_path, exist_ok=True)

    morpher_pipeline = DiffMorpherPipeline.from_pretrained(
        model_path, torch_dtype=torch.float32
    ).to("cuda")

    # Anything other than the two "Fix" choices means LoRA interpolation.
    fix_lora = {"Fix LoRA A": 0, "Fix LoRA B": 1}.get(lora_mode)

    # Fall back to the checkpoints the training callbacks write by default.
    if not load_lora_path_0:
        load_lora_path_0 = f"{output_path}/lora_0.ckpt"
    if not load_lora_path_1:
        load_lora_path_1 = f"{output_path}/lora_1.ckpt"

    # NOTE(review): main.py passes this setting as `lamd=`; confirm the
    # pipeline accepts the `lamb` keyword used here.
    images = morpher_pipeline(
        img_0=image_0,
        img_1=image_1,
        prompt_0=prompt_0,
        prompt_1=prompt_1,
        load_lora_path_0=load_lora_path_0,
        load_lora_path_1=load_lora_path_1,
        lamb=lamb,
        use_adain=use_adain,
        use_reschedule=use_reschedule,
        num_frames=num_frames,
        fix_lora=fix_lora,
        save_intermediates=save_inter,
        progress=gr.Progress()
    )

    video_path = f"{output_path}/{run_id}.mp4"
    video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (512, 512))
    try:
        for image in images:
            # Frames are RGB; OpenCV writers expect BGR channel order.
            video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
    finally:
        # Always finalize the file, even if a frame conversion/write fails.
        video.release()
    cv2.destroyAllWindows()

    return gr.Video(value=video_path, format="mp4", label="Output video", show_label=True, height=LENGTH, width=LENGTH, interactive=False)
|
81 |
+
|
82 |
+
def run_all(
    image_0,
    image_1,
    prompt_0,
    prompt_1,
    model_path,
    lora_mode,
    lamb,
    use_adain,
    use_reschedule,
    num_frames,
    fps,
    save_inter,
    load_lora_path_0,
    load_lora_path_1,
    output_path,
    lora_steps,
    lora_rank,
    lora_lr
):
    """Train both LoRAs, then run the morphing pipeline in one go.

    LoRA A is fitted to ``image_0``/``prompt_0`` and saved as ``lora_0.ckpt``,
    LoRA B to ``image_1``/``prompt_1`` as ``lora_1.ckpt``, both under
    ``output_path``. The remaining arguments are handed unchanged to
    ``run_diffmorpher``, whose result is returned.
    """
    os.makedirs(output_path, exist_ok=True)

    # (image, prompt, checkpoint file name) for LoRA A and LoRA B, in order.
    training_jobs = (
        (image_0, prompt_0, "lora_0.ckpt"),
        (image_1, prompt_1, "lora_1.ckpt"),
    )
    for job_image, job_prompt, job_weight_name in training_jobs:
        train_lora(
            job_image,
            job_prompt,
            output_path,
            model_path,
            lora_steps=lora_steps,
            lora_lr=lora_lr,
            lora_rank=lora_rank,
            weight_name=job_weight_name,
            progress=gr.Progress(),
        )

    return run_diffmorpher(
        image_0,
        image_1,
        prompt_0,
        prompt_1,
        model_path,
        lora_mode,
        lamb,
        use_adain,
        use_reschedule,
        num_frames,
        fps,
        save_inter,
        load_lora_path_0,
        load_lora_path_1,
        output_path,
    )
|
124 |
+
|
125 |
+
with gr.Blocks() as demo:
|
126 |
+
|
127 |
+
with gr.Row():
|
128 |
+
gr.Markdown("""
|
129 |
+
# Official Implementation of [DiffMorpher](https://kevin-thu.github.io/DiffMorpher_page/)
|
130 |
+
""")
|
131 |
+
|
132 |
+
original_image_0, original_image_1 = gr.State(Image.open("assets/Trump.jpg").convert("RGB").resize((512,512), Image.BILINEAR)), gr.State(Image.open("assets/Biden.jpg").convert("RGB").resize((512,512), Image.BILINEAR))
|
133 |
+
# key_points_0, key_points_1 = gr.State([]), gr.State([])
|
134 |
+
# to_change_points = gr.State([])
|
135 |
+
|
136 |
+
with gr.Row():
|
137 |
+
with gr.Column():
|
138 |
+
input_img_0 = gr.Image(type="numpy", label="Input image A", value="assets/Trump.jpg", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
|
139 |
+
prompt_0 = gr.Textbox(label="Prompt for image A", value="a photo of an American man", interactive=True)
|
140 |
+
with gr.Row():
|
141 |
+
train_lora_0_button = gr.Button("Train LoRA A")
|
142 |
+
train_lora_1_button = gr.Button("Train LoRA B")
|
143 |
+
# show_correspond_button = gr.Button("Show correspondence points")
|
144 |
+
with gr.Column():
|
145 |
+
input_img_1 = gr.Image(type="numpy", label="Input image B ", value="assets/Biden.jpg", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
|
146 |
+
prompt_1 = gr.Textbox(label="Prompt for image B", value="a photo of an American man", interactive=True)
|
147 |
+
with gr.Row():
|
148 |
+
clear_button = gr.Button("Clear All")
|
149 |
+
run_button = gr.Button("Run w/o LoRA training")
|
150 |
+
with gr.Column():
|
151 |
+
output_video = gr.Video(format="mp4", label="Output video", show_label=True, height=LENGTH, width=LENGTH, interactive=False)
|
152 |
+
lora_progress_bar = gr.Textbox(label="Display LoRA training progress", interactive=False)
|
153 |
+
run_all_button = gr.Button("Run!")
|
154 |
+
# with gr.Column():
|
155 |
+
# output_video = gr.Video(label="Output video", show_label=True, height=LENGTH, width=LENGTH)
|
156 |
+
|
157 |
+
with gr.Row():
|
158 |
+
gr.Markdown("""
|
159 |
+
### Usage:
|
160 |
+
1. Upload two images (with correspondence) and fill out the prompts.
|
161 |
+
(It's recommended to change `[Output path]` accordingly.)
|
162 |
+
2. Click **"Run!"**
|
163 |
+
|
164 |
+
Or:
|
165 |
+
1. Upload two images (with correspondence) and fill out the prompts.
|
166 |
+
2. Click the **"Train LoRA A/B"** button to fit two LoRAs for two images respectively. <br>
|
167 |
+
If you have trained LoRA A or LoRA B before, you can skip the step and fill the specific LoRA path in LoRA settings. <br>
|
168 |
+
Trained LoRAs are saved to `[Output Path]/lora_0.ckpt` and `[Output Path]/lora_1.ckpt` by default.
|
169 |
+
3. You might also change the settings below.
|
170 |
+
4. Click **"Run w/o LoRA training"**
|
171 |
+
|
172 |
+
### Note:
|
173 |
+
1. To speed up the generation process, you can **ruduce the number of frames** or **turn off "Use Reschedule"**.
|
174 |
+
2. You can try the influence of different prompts. It seems that using the same prompts or aligned prompts works better.
|
175 |
+
### Have fun!
|
176 |
+
""")
|
177 |
+
|
178 |
+
with gr.Accordion(label="Algorithm Parameters"):
|
179 |
+
with gr.Tab("Basic Settings"):
|
180 |
+
with gr.Row():
|
181 |
+
# local_models_dir = 'local_pretrained_models'
|
182 |
+
# local_models_choice = \
|
183 |
+
# [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
|
184 |
+
model_path = gr.Text(value="stabilityai/stable-diffusion-2-1-base",
|
185 |
+
label="Diffusion Model Path", interactive=True
|
186 |
+
)
|
187 |
+
lamb = gr.Slider(value=0.6, minimum=0, maximum=1, step=0.1, label="Lambda for attention replacement", interactive=True)
|
188 |
+
lora_mode = gr.Dropdown(value="LoRA Interp",
|
189 |
+
label="LoRA Interp. or Fix LoRA",
|
190 |
+
choices=["LoRA Interp", "Fix LoRA A", "Fix LoRA B"],
|
191 |
+
interactive=True
|
192 |
+
)
|
193 |
+
use_adain = gr.Checkbox(value=True, label="Use AdaIN", interactive=True)
|
194 |
+
use_reschedule = gr.Checkbox(value=True, label="Use Reschedule", interactive=True)
|
195 |
+
with gr.Row():
|
196 |
+
num_frames = gr.Number(value=16, minimum=0, label="Number of Frames", precision=0, interactive=True)
|
197 |
+
fps = gr.Number(value=8, minimum=0, label="FPS (Frame rate)", precision=0, interactive=True)
|
198 |
+
save_inter = gr.Checkbox(value=False, label="Save Intermediate Images", interactive=True)
|
199 |
+
output_path = gr.Text(value="./results", label="Output Path", interactive=True)
|
200 |
+
|
201 |
+
with gr.Tab("LoRA Settings"):
|
202 |
+
with gr.Row():
|
203 |
+
lora_steps = gr.Number(value=200, label="LoRA training steps", precision=0, interactive=True)
|
204 |
+
lora_lr = gr.Number(value=0.0002, label="LoRA learning rate", interactive=True)
|
205 |
+
lora_rank = gr.Number(value=16, label="LoRA rank", precision=0, interactive=True)
|
206 |
+
# save_lora_dir = gr.Text(value="./lora", label="LoRA model save path", interactive=True)
|
207 |
+
load_lora_path_0 = gr.Text(value="", label="LoRA model load path for image A", interactive=True)
|
208 |
+
load_lora_path_1 = gr.Text(value="", label="LoRA model load path for image B", interactive=True)
|
209 |
+
|
210 |
+
def store_img(img):
|
211 |
+
image = Image.fromarray(img).convert("RGB").resize((512,512), Image.BILINEAR)
|
212 |
+
# resize the input to 512x512
|
213 |
+
# image = image.resize((512,512), Image.BILINEAR)
|
214 |
+
# image = np.array(image)
|
215 |
+
# when new image is uploaded, `selected_points` should be empty
|
216 |
+
return image
|
217 |
+
input_img_0.upload(
|
218 |
+
store_img,
|
219 |
+
[input_img_0],
|
220 |
+
[original_image_0]
|
221 |
+
)
|
222 |
+
input_img_1.upload(
|
223 |
+
store_img,
|
224 |
+
[input_img_1],
|
225 |
+
[original_image_1]
|
226 |
+
)
|
227 |
+
|
228 |
+
def clear(LENGTH):
|
229 |
+
return gr.Image.update(value=None, width=LENGTH, height=LENGTH), \
|
230 |
+
gr.Image.update(value=None, width=LENGTH, height=LENGTH), \
|
231 |
+
None, None, None, None
|
232 |
+
clear_button.click(
|
233 |
+
clear,
|
234 |
+
[gr.Number(value=LENGTH, visible=False, precision=0)],
|
235 |
+
[input_img_0, input_img_1, original_image_0, original_image_1, prompt_0, prompt_1]
|
236 |
+
)
|
237 |
+
|
238 |
+
train_lora_0_button.click(
|
239 |
+
train_lora_interface,
|
240 |
+
[
|
241 |
+
original_image_0,
|
242 |
+
prompt_0,
|
243 |
+
model_path,
|
244 |
+
output_path,
|
245 |
+
lora_steps,
|
246 |
+
lora_rank,
|
247 |
+
lora_lr,
|
248 |
+
gr.Number(value=0, visible=False, precision=0)
|
249 |
+
],
|
250 |
+
[lora_progress_bar]
|
251 |
+
)
|
252 |
+
|
253 |
+
train_lora_1_button.click(
|
254 |
+
train_lora_interface,
|
255 |
+
[
|
256 |
+
original_image_1,
|
257 |
+
prompt_1,
|
258 |
+
model_path,
|
259 |
+
output_path,
|
260 |
+
lora_steps,
|
261 |
+
lora_rank,
|
262 |
+
lora_lr,
|
263 |
+
gr.Number(value=1, visible=False, precision=0)
|
264 |
+
],
|
265 |
+
[lora_progress_bar]
|
266 |
+
)
|
267 |
+
|
268 |
+
run_button.click(
|
269 |
+
run_diffmorpher,
|
270 |
+
[
|
271 |
+
original_image_0,
|
272 |
+
original_image_1,
|
273 |
+
prompt_0,
|
274 |
+
prompt_1,
|
275 |
+
model_path,
|
276 |
+
lora_mode,
|
277 |
+
lamb,
|
278 |
+
use_adain,
|
279 |
+
use_reschedule,
|
280 |
+
num_frames,
|
281 |
+
fps,
|
282 |
+
save_inter,
|
283 |
+
load_lora_path_0,
|
284 |
+
load_lora_path_1,
|
285 |
+
output_path
|
286 |
+
],
|
287 |
+
[output_video]
|
288 |
+
)
|
289 |
+
|
290 |
+
run_all_button.click(
|
291 |
+
run_all,
|
292 |
+
[
|
293 |
+
original_image_0,
|
294 |
+
original_image_1,
|
295 |
+
prompt_0,
|
296 |
+
prompt_1,
|
297 |
+
model_path,
|
298 |
+
lora_mode,
|
299 |
+
lamb,
|
300 |
+
use_adain,
|
301 |
+
use_reschedule,
|
302 |
+
num_frames,
|
303 |
+
fps,
|
304 |
+
save_inter,
|
305 |
+
load_lora_path_0,
|
306 |
+
load_lora_path_1,
|
307 |
+
output_path,
|
308 |
+
lora_steps,
|
309 |
+
lora_rank,
|
310 |
+
lora_lr
|
311 |
+
],
|
312 |
+
[output_video]
|
313 |
+
)
|
314 |
+
|
315 |
+
demo.queue().launch(debug=True)
|
assets/Biden.jpg
ADDED
assets/Feifei.jpg
ADDED
assets/Musk.jpg
ADDED
assets/Teaser.png
ADDED
Git LFS Details
|
assets/Trump.jpg
ADDED
assets/cat.png
ADDED
assets/dog.png
ADDED
Git LFS Details
|
assets/dog_sit.png
ADDED
Git LFS Details
|
assets/drag_realgirl0.png
ADDED
assets/drag_realgirl1.png
ADDED
assets/drag_sculp0.png
ADDED
assets/drag_sculp1.png
ADDED
assets/fuji_0.jpg
ADDED
assets/fuji_1.jpg
ADDED
assets/house0.jpg
ADDED
assets/house1.jpg
ADDED
assets/jeep.jpg
ADDED
assets/leo_0.jpg
ADDED
assets/leo_1.jpg
ADDED
assets/lion.png
ADDED
assets/man_paint.png
ADDED
assets/mit.jpg
ADDED
assets/monalisa.jpeg
ADDED
assets/obama.jpg
ADDED
assets/pearlgirl.jpg
ADDED
assets/rabbit.png
ADDED
assets/sculp0.png
ADDED
assets/sculp1.png
ADDED
assets/teaser.gif
ADDED
Git LFS Details
|
assets/thu.jpg
ADDED
assets/tiger.png
ADDED
assets/van.jpg
ADDED
assets/vangogh.jpg
ADDED
assets/vangogh_hat.png
ADDED
assets/wave_paint.png
ADDED
assets/wave_real.jpg
ADDED
main.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
from PIL import Image
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from model import DiffMorpherPipeline
|
8 |
+
|
9 |
+
# Command-line interface of the morphing script. Each option's help text
# shows its default via argparse's %(default)s substitution.
parser = ArgumentParser()
parser.add_argument(
    "--model_path", type=str, default="stabilityai/stable-diffusion-2-1-base",
    help="Pretrained model to use (default: %(default)s)"
)
parser.add_argument(
    "--image_path_0", type=str, default="",
    help="Path of the first image (default: %(default)s)"
)
parser.add_argument(
    "--prompt_0", type=str, default="",
    help="Prompt of the first image (default: %(default)s)"
)
parser.add_argument(
    "--image_path_1", type=str, default="",
    help="Path of the second image (default: %(default)s)"
)
parser.add_argument(
    "--prompt_1", type=str, default="",
    help="Prompt of the second image (default: %(default)s)"
)
parser.add_argument(
    # Used as a directory: the script creates it and writes output.gif inside.
    "--output_path", type=str, default="./results",
    help="Path of the output directory (default: %(default)s)"
)
parser.add_argument(
    "--save_lora_dir", type=str, default="./lora",
    help="Path of the output lora directory (default: %(default)s)"
)
parser.add_argument(
    "--load_lora_path_0", type=str, default="",
    help="Path of the lora directory of the first image (default: %(default)s)"
)
parser.add_argument(
    "--load_lora_path_1", type=str, default="",
    help="Path of the lora directory of the second image (default: %(default)s)"
)
parser.add_argument(
    "--use_adain", action="store_true",
    help="Use AdaIN (default: %(default)s)"
)
parser.add_argument(
    "--use_reschedule", action="store_true",
    help="Use reschedule sampling (default: %(default)s)"
)
parser.add_argument(
    "--lamb", type=float, default=0.6,
    help="Lambda for self-attention replacement (default: %(default)s)"
)
parser.add_argument(
    "--fix_lora_value", type=float, default=None,
    help="Fix lora value (default: LoRA Interp., not fixed)"
)
parser.add_argument(
    "--save_inter", action="store_true",
    help="Save intermediate results (default: %(default)s)"
)
parser.add_argument(
    "--num_frames", type=int, default=16,
    help="Number of frames to generate (default: %(default)s)"
)
parser.add_argument(
    "--duration", type=int, default=100,
    help="Duration of each frame (default: %(default)s ms)"
)
parser.add_argument(
    # The script forwards the negation of this flag as `use_lora`.
    "--no_lora", action="store_true",
    help="Run the pipeline with use_lora=False (default: %(default)s)"
)
|
73 |
+
|
74 |
+
args = parser.parse_args()

# Make sure the output directory exists before anything is written into it.
os.makedirs(args.output_path, exist_ok=True)

# Load the pipeline in float32 and move it to the GPU.
pipeline = DiffMorpherPipeline.from_pretrained(
    args.model_path,
    torch_dtype=torch.float32,
)
pipeline.to("cuda")

# Keyword arguments for the pipeline call, taken from the CLI options.
morph_kwargs = dict(
    img_path_0=args.image_path_0,
    img_path_1=args.image_path_1,
    prompt_0=args.prompt_0,
    prompt_1=args.prompt_1,
    save_lora_dir=args.save_lora_dir,
    load_lora_path_0=args.load_lora_path_0,
    load_lora_path_1=args.load_lora_path_1,
    use_adain=args.use_adain,
    use_reschedule=args.use_reschedule,
    lamd=args.lamb,
    output_path=args.output_path,
    num_frames=args.num_frames,
    fix_lora=args.fix_lora_value,
    save_intermediates=args.save_inter,
    use_lora=not args.no_lora,
)
images = pipeline(**morph_kwargs)

# Write all frames as one looping GIF; the first frame carries the rest.
gif_path = f"{args.output_path}/output.gif"
images[0].save(
    gif_path,
    save_all=True,
    append_images=images[1:],
    duration=args.duration,
    loop=0,
)
|
model.py
ADDED
@@ -0,0 +1,639 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
3 |
+
from diffusers.models.attention_processor import AttnProcessor
|
4 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
5 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import safetensors
|
11 |
+
from PIL import Image
|
12 |
+
from torchvision import transforms
|
13 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
14 |
+
from diffusers import StableDiffusionPipeline
|
15 |
+
from argparse import ArgumentParser
|
16 |
+
import inspect
|
17 |
+
|
18 |
+
from utils.model_utils import get_img, slerp, do_replace_attn
|
19 |
+
from utils.lora_utils import train_lora, load_lora
|
20 |
+
from utils.alpha_scheduler import AlphaScheduler
|
21 |
+
|
22 |
+
|
23 |
+
class StoreProcessor():
    """Attention-processor wrapper that records self-attention inputs.

    Every call made without ``encoder_hidden_states`` (i.e. self attention)
    stores a detached copy of ``hidden_states`` in
    ``value_dict[name][call_index]`` and then delegates to the wrapped
    processor. Calls with ``encoder_hidden_states`` are forwarded untouched.
    """

    def __init__(self, original_processor, value_dict, name):
        self.name = name
        self.value_dict = value_dict
        self.original_processor = original_processor
        # Replace whatever was stored under this name with an empty slot.
        self.value_dict[self.name] = dict()
        # Index of the next self-attention call to be recorded.
        self.id = 0

    def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
        is_self_attention = encoder_hidden_states is None
        if is_self_attention:
            recorded = self.value_dict[self.name]
            recorded[self.id] = hidden_states.detach()
            self.id += 1

        return self.original_processor(
            attn, hidden_states, *args,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **kwargs)
|
42 |
+
|
43 |
+
|
44 |
+
class LoadProcessor():
    """Attention-processor wrapper that injects blended stored attention inputs.

    For self-attention calls (no ``encoder_hidden_states``) during the first
    ``lamd`` fraction of the recorded calls, the wrapped processor receives a
    blend of the maps stored for both images as ``encoder_hidden_states``:

        beta * hidden_states + (1 - beta) * ((1 - alpha) * map0 + alpha * map1)

    All other calls are forwarded unchanged.

    Args:
        original_processor: processor that performs the actual attention.
        name: key of this layer inside ``img0_dict`` / ``img1_dict``.
        img0_dict: ``{name: {call_index: tensor}}`` recorded for the first image.
        img1_dict: same structure, recorded for the second image.
        alpha: interpolation weight between the two stored maps.
        beta: weight of the live ``hidden_states`` in the blend.
        lamd: fraction of the recorded calls during which the blend is used.
    """

    def __init__(self, original_processor, name, img0_dict, img1_dict, alpha, beta=0, lamd=0.6):
        super().__init__()
        self.original_processor = original_processor
        self.name = name
        self.img0_dict = img0_dict
        self.img1_dict = img1_dict
        self.alpha = alpha
        self.beta = beta
        self.lamd = lamd
        # Index of the current self-attention call within one sampling run.
        self.id = 0

    def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
        # Cross attention: nothing to replace, just delegate.
        if encoder_hidden_states is not None:
            return self.original_processor(attn, hidden_states, *args,
                                           encoder_hidden_states=encoder_hidden_states,
                                           attention_mask=attention_mask,
                                           **kwargs)

        stored_maps = self.img0_dict[self.name]
        # The window used to be ``50 * lamd``, which is only right for a
        # 50-step run. The number of recorded maps is the actual number of
        # self-attention calls per run, so derive the window from it.
        total_calls = len(stored_maps)

        if self.id < total_calls * self.lamd:
            map0 = stored_maps[self.id]
            map1 = self.img1_dict[self.name][self.id]
            cross_map = self.beta * hidden_states + \
                (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
            res = self.original_processor(attn, hidden_states, *args,
                                          encoder_hidden_states=cross_map,
                                          attention_mask=attention_mask,
                                          **kwargs)
        else:
            res = self.original_processor(attn, hidden_states, *args,
                                          encoder_hidden_states=encoder_hidden_states,
                                          attention_mask=attention_mask,
                                          **kwargs)

        # Wrap around so the same instance can serve another sampling run.
        self.id += 1
        if self.id == total_calls:
            self.id = 0

        return res
|
93 |
+
|
94 |
+
|
95 |
+
class DiffMorpherPipeline(StableDiffusionPipeline):
|
96 |
+
|
97 |
+
def __init__(self,
             vae: AutoencoderKL,
             text_encoder: CLIPTextModel,
             tokenizer: CLIPTokenizer,
             unet: UNet2DConditionModel,
             scheduler: KarrasDiffusionSchedulers,
             safety_checker: StableDiffusionSafetyChecker,
             feature_extractor: CLIPImageProcessor,
             image_encoder=None,
             requires_safety_checker: bool = True,
             ):
    """Initialise the parent pipeline and the per-image attention stores.

    The parent constructor's signature is inspected at run time:
    ``image_encoder`` is forwarded only when the parent ``__init__`` declares
    a parameter with that name. All components are passed positionally.
    """
    parent_init = super().__init__
    components = [vae, text_encoder, tokenizer, unet, scheduler,
                  safety_checker, feature_extractor]
    if 'image_encoder' in inspect.signature(parent_init).parameters:
        components.append(image_encoder)
    parent_init(*components, requires_safety_checker)

    # Filled by StoreProcessor / read by LoadProcessor during morphing.
    self.img0_dict = dict()
    self.img1_dict = dict()
|
118 |
+
|
119 |
+
def inv_step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    x: torch.FloatTensor,
    eta=0.,
    verbose=False
):
    """
    Single inverse-sampling step for DDIM inversion.

    ``timestep`` is the step being moved *to*; the step being moved *from* is
    one scheduler stride earlier (capped at 999). ``eta`` is accepted but not
    used. Returns ``(x_next, pred_x0)``.
    """
    if verbose:
        print("timestep: ", timestep)

    sched = self.scheduler
    stride = sched.config.num_train_timesteps // sched.num_inference_steps
    target_t = timestep
    source_t = min(timestep - stride, 999)

    # A negative source index means "before the first step": use the
    # scheduler's final_alpha_cumprod instead of indexing the table.
    if source_t >= 0:
        alpha_source = sched.alphas_cumprod[source_t]
    else:
        alpha_source = sched.final_alpha_cumprod
    alpha_target = sched.alphas_cumprod[target_t]

    pred_x0 = (x - (1 - alpha_source)**0.5 * model_output) / alpha_source**0.5
    direction = (1 - alpha_target)**0.5 * model_output
    x_next = alpha_target**0.5 * pred_x0 + direction
    return x_next, pred_x0
|
143 |
+
|
144 |
+
@torch.no_grad()
def invert(
        self,
        image: torch.Tensor,
        prompt,
        num_inference_steps=50,
        num_actual_inference_steps=None,
        guidance_scale=1.,
        eta=0.0,
        **kwds):
    """
    Invert a real image into a noise map with deterministic DDIM inversion.

    Args:
        image: image batch; it is passed to ``self.image2latent``.
        prompt: a string or a list of strings. A list combined with a single
            image expands the image to one copy per prompt; a string combined
            with several images is repeated once per image.
        num_inference_steps: number of scheduler timesteps.
        num_actual_inference_steps: if given, only this many inversion steps
            are executed.
        guidance_scale: values above 1 enable classifier-free guidance.
        eta: accepted for interface compatibility; not used.

    Returns:
        The latents after the last executed inversion step.
    """
    DEVICE = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    batch_size = image.shape[0]
    if isinstance(prompt, list):
        if batch_size == 1:
            image = image.expand(len(prompt), -1, -1, -1)
    elif isinstance(prompt, str):
        if batch_size > 1:
            prompt = [prompt] * batch_size

    # text embeddings
    text_input = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        return_tensors="pt"
    )
    text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
    print("input text embeddings :", text_embeddings.shape)
    # define initial latents
    latents = self.image2latent(image)

    use_guidance = guidance_scale > 1.

    # unconditional embedding for classifier-free guidance
    if use_guidance:
        # Size the unconditional batch from the conditional embeddings rather
        # than the incoming image batch: when one image is expanded to match
        # a prompt list the two differ, and the UNet input (2x latents) would
        # not line up with the concatenated embeddings.
        unconditional_input = self.tokenizer(
            [""] * text_embeddings.shape[0],
            padding="max_length",
            max_length=77,
            return_tensors="pt"
        )
        unconditional_embeddings = self.text_encoder(
            unconditional_input.input_ids.to(DEVICE))[0]
        text_embeddings = torch.cat(
            [unconditional_embeddings, text_embeddings], dim=0)

    print("latents shape: ", latents.shape)
    # iterative sampling, walking the timesteps in reverse order
    self.scheduler.set_timesteps(num_inference_steps)
    print("Valid timesteps: ", reversed(self.scheduler.timesteps))
    for i, t in enumerate(tqdm.tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
        if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
            # Every remaining iteration would be skipped as well.
            break

        if use_guidance:
            model_inputs = torch.cat([latents] * 2)
        else:
            model_inputs = latents

        # predict the noise
        noise_pred = self.unet(
            model_inputs, t, encoder_hidden_states=text_embeddings).sample
        if use_guidance:
            noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncon + guidance_scale * \
                (noise_pred_con - noise_pred_uncon)
        # compute the next (noisier) sample x_t-1 -> x_t
        latents, _ = self.inv_step(noise_pred, t, latents)

    return latents
|
222 |
+
|
223 |
+
@torch.no_grad()
def ddim_inversion(self, latent, cond):
    """Run DDIM inversion on ``latent`` over the scheduler's current timesteps.

    The timesteps are visited in reverse order. ``cond`` is repeated along
    dim 0 to match the latent batch and passed to the UNet as
    ``encoder_hidden_states``. Returns the latent after the last step.
    """
    sched = self.scheduler
    timesteps = reversed(sched.timesteps)
    with torch.autocast(device_type='cuda', dtype=torch.float32):
        for i, t in enumerate(tqdm.tqdm(timesteps, desc="DDIM inversion")):
            cond_batch = cond.repeat(latent.shape[0], 1, 1)

            alpha_to = sched.alphas_cumprod[t]
            # The first step has no preceding entry in the reversed list, so
            # it starts from the scheduler's final_alpha_cumprod.
            if i > 0:
                alpha_from = sched.alphas_cumprod[timesteps[i - 1]]
            else:
                alpha_from = sched.final_alpha_cumprod

            eps = self.unet(
                latent, t, encoder_hidden_states=cond_batch).sample

            pred_x0 = (latent - (1 - alpha_from) ** 0.5 * eps) / alpha_from ** 0.5
            latent = alpha_to ** 0.5 * pred_x0 + (1 - alpha_to) ** 0.5 * eps
    return latent
|
250 |
+
|
251 |
+
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    x: torch.FloatTensor,
):
    """
    Predict the sample of the next step in the denoising process.

    Args:
        model_output: noise predicted by the model at ``timestep``.
        timestep: current timestep (index into ``alphas_cumprod``).
        x: current sample.

    Returns:
        ``(x_prev, pred_x0)``: the sample one scheduler stride earlier and the
        predicted clean sample.
    """
    sched = self.scheduler
    prev_timestep = timestep - \
        sched.config.num_train_timesteps // sched.num_inference_steps
    alpha_prod_t = sched.alphas_cumprod[timestep]
    # Index 0 is a valid table entry; only a negative index falls back to
    # final_alpha_cumprod. This matches inv_step, which uses ``>= 0`` too.
    alpha_prod_t_prev = sched.alphas_cumprod[
        prev_timestep] if prev_timestep >= 0 else sched.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t
    pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
    pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
    x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
    return x_prev, pred_x0
|
270 |
+
|
271 |
+
@torch.no_grad()
def image2latent(self, image):
    """Encode an image into scaled VAE latents.

    Args:
        image: either a ``torch.Tensor`` that is fed to the VAE as-is, or any
            array-like accepted by ``np.array`` (e.g. a PIL image or a NumPy
            array). Array-likes are treated as H x W x C pixel data with
            values in [0, 255] and converted to a 1 x C x H x W float tensor
            in [-1, 1].

    Returns:
        The mean of the VAE latent distribution multiplied by 0.18215.
    """
    DEVICE = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    # The previous check, ``type(image) is Image``, compared against the
    # PIL.Image *module* and was therefore never true, so non-tensor input
    # reached the VAE unconverted. Convert everything that is not a tensor.
    if not isinstance(image, torch.Tensor):
        image = np.array(image)
        image = torch.from_numpy(image).float() / 127.5 - 1
        image = image.permute(2, 0, 1).unsqueeze(0)
    # input image density range [-1, 1]
    latents = self.vae.encode(image.to(DEVICE))['latent_dist'].mean
    latents = latents * 0.18215
    return latents
|
283 |
+
|
284 |
+
@torch.no_grad()
def latent2image(self, latents, return_type='np'):
    """Decode latents with the VAE.

    ``return_type='np'`` returns the first batch item as an H x W x C uint8
    array, ``'pt'`` returns the decoded tensor mapped to [0, 1], and any other
    value returns the raw decoder output.
    """
    rescaled = 1 / 0.18215 * latents.detach()
    decoded = self.vae.decode(rescaled)['sample']

    if return_type not in ('np', 'pt'):
        return decoded

    # Map the decoder output from [-1, 1] to [0, 1].
    unit_range = (decoded / 2 + 0.5).clamp(0, 1)
    if return_type == 'pt':
        return unit_range

    pixels = unit_range.cpu().permute(0, 2, 3, 1).numpy()[0]
    return (pixels * 255).astype(np.uint8)
|
296 |
+
|
297 |
+
def latent2image_grad(self, latents):
    """Decode latents with the VAE without detaching them or disabling autograd."""
    rescaled = 1 / 0.18215 * latents
    decoded = self.vae.decode(rescaled)['sample']
    return decoded  # range [-1, 1]
|
302 |
+
|
303 |
+
@torch.no_grad()
def cal_latent(self, num_inference_steps, guidance_scale, unconditioning, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1, alpha, use_lora, fix_lora=None):
    """Sample the latent of one intermediate frame at interpolation weight ``alpha``.

    The start noise comes from ``slerp`` on the two inverted noises and the
    text embedding is a linear blend of the two prompts' embeddings. When
    ``use_lora`` is true the UNet is re-loaded with the two LoRAs at weight
    ``alpha`` (or ``fix_lora`` when that is not None). The latent is then
    denoised over the scheduler's timesteps and returned.
    """
    latents = slerp(img_noise_0, img_noise_1, alpha, self.use_adain)
    text_embeddings = (1 - alpha) * text_embeddings_0 + \
        alpha * text_embeddings_1

    self.scheduler.set_timesteps(num_inference_steps)
    if use_lora:
        lora_weight = alpha if fix_lora is None else fix_lora
        self.unet = load_lora(self.unet, lora_0, lora_1, lora_weight)

    # A list of unconditional embeddings supplies one entry per step.
    per_step_uncond = unconditioning is not None and isinstance(unconditioning, list)

    for i, t in enumerate(tqdm.tqdm(self.scheduler.timesteps, desc=f"DDIM Sampler, alpha={alpha}")):

        model_inputs = torch.cat([latents] * 2) if guidance_scale > 1. else latents

        if per_step_uncond:
            # Keep the conditional half and pair it with this step's
            # unconditional embedding.
            _, cond_half = text_embeddings.chunk(2)
            text_embeddings = torch.cat(
                [unconditioning[i].expand(*cond_half.shape), cond_half])

        # predict the noise
        noise_pred = self.unet(
            model_inputs, t, encoder_hidden_states=text_embeddings).sample
        if guidance_scale > 1.0:
            uncond_pred, cond_pred = noise_pred.chunk(2, dim=0)
            noise_pred = uncond_pred + guidance_scale * \
                (cond_pred - uncond_pred)

        # compute the previous noise sample x_t -> x_t-1
        step_output = self.scheduler.step(
            noise_pred, t, latents, return_dict=False)
        latents = step_output[0]
    return latents
|
342 |
+
|
343 |
+
@torch.no_grad()
def get_text_embeddings(self, prompt, guidance_scale, neg_prompt, batch_size):
    """Encode ``prompt`` and, when guidance is enabled, prepend the unconditional embedding.

    Args:
        prompt: text passed to the tokenizer.
        guidance_scale: values above 1 enable classifier-free guidance.
        neg_prompt: text used for the unconditional branch; an empty string
            is used when this is falsy.
        batch_size: number of copies of the unconditional text to encode.

    Returns:
        The prompt embeddings, or ``cat([unconditional, prompt], dim=0)`` when
        ``guidance_scale > 1``.
    """
    DEVICE = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    # text embeddings
    text_input = self.tokenizer(
        prompt,
        padding="max_length",
        max_length=77,
        return_tensors="pt"
    )
    # Use the selected device (not a hard-coded ``.cuda()``) so the prompt
    # branch works on CPU-only hosts, like the unconditional branch below.
    text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]

    if guidance_scale > 1.:
        uc_text = neg_prompt if neg_prompt else ""
        unconditional_input = self.tokenizer(
            [uc_text] * batch_size,
            padding="max_length",
            max_length=77,
            return_tensors="pt"
        )
        unconditional_embeddings = self.text_encoder(
            unconditional_input.input_ids.to(DEVICE))[0]
        text_embeddings = torch.cat(
            [unconditional_embeddings, text_embeddings], dim=0)

    return text_embeddings
|
373 |
+
|
374 |
+
def __call__(
        self,
        img_0=None,
        img_1=None,
        img_path_0=None,
        img_path_1=None,
        prompt_0="",
        prompt_1="",
        save_lora_dir="./lora",
        load_lora_path_0=None,
        load_lora_path_1=None,
        lora_steps=200,
        lora_lr=2e-4,
        lora_rank=16,
        batch_size=1,
        height=512,
        width=512,
        num_inference_steps=50,
        num_actual_inference_steps=None,
        guidance_scale=1,
        attn_beta=0,
        lamd=0.6,
        use_lora=True,
        use_adain=True,
        use_reschedule=True,
        output_path="./results",
        num_frames=50,
        fix_lora=None,
        progress=tqdm,
        unconditioning=None,
        neg_prompt=None,
        save_intermediates=False,
        **kwds):
    """Generate a morphing sequence between two images.

    Args:
        img_0, img_1: input images; when None the image is opened from
            ``img_path_0`` / ``img_path_1`` and converted to RGB.
        prompt_0, prompt_1: prompts for the two images.
        save_lora_dir: directory used for LoRA weights that are not given
            explicitly; missing weights are trained with ``train_lora``.
        load_lora_path_0, load_lora_path_1: explicit LoRA weight files
            (``.safetensors`` or anything ``torch.load`` can read).
        lora_steps, lora_lr, lora_rank: forwarded to ``train_lora``.
        batch_size: forwarded to ``get_text_embeddings``.
        height, width, num_actual_inference_steps: accepted but not used here.
        num_inference_steps: number of scheduler steps.
        guidance_scale: classifier-free guidance scale.
        attn_beta: forwarded to ``LoadProcessor`` as ``beta``; ``None``
            disables attention-processor replacement altogether.
        lamd: forwarded to ``LoadProcessor``.
        use_lora, use_adain, use_reschedule: feature switches, also stored on
            the instance.
        output_path: directory for intermediate frames; its last path
            component also names default LoRA files.
        num_frames: number of alpha values sampled in [0, 1].
        fix_lora: if not None, used as the LoRA weight for every frame.
        progress: object exposing ``tqdm(iterable, desc=...)``.
        unconditioning: forwarded to ``cal_latent``.
        neg_prompt: negative prompt for the unconditional embedding.
        save_intermediates: save every frame as ``NN.png`` in ``output_path``.

    Returns:
        A list of PIL images.
    """
    self.scheduler.set_timesteps(num_inference_steps)
    self.use_lora = use_lora
    self.use_adain = use_adain
    self.use_reschedule = use_reschedule
    self.output_path = output_path

    if img_0 is None:
        img_0 = Image.open(img_path_0).convert("RGB")

    if img_1 is None:
        img_1 = Image.open(img_path_1).convert("RGB")

    def read_lora_weights(path):
        # Pick the loader from the file extension; weights stay on the CPU.
        if path.endswith(".safetensors"):
            return safetensors.torch.load_file(path, device="cpu")
        return torch.load(path, map_location="cpu")

    if not self.use_lora:
        lora_0 = lora_1 = None
    else:
        print("Loading lora...")
        if not load_lora_path_0:
            weight_name = f"{output_path.split('/')[-1]}_lora_0.ckpt"
            load_lora_path_0 = save_lora_dir + "/" + weight_name
            if not os.path.exists(load_lora_path_0):
                train_lora(img_0, prompt_0, save_lora_dir, None, self.tokenizer, self.text_encoder,
                           self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
        print(f"Load from {load_lora_path_0}.")
        lora_0 = read_lora_weights(load_lora_path_0)

        if not load_lora_path_1:
            weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt"
            load_lora_path_1 = save_lora_dir + "/" + weight_name
            if not os.path.exists(load_lora_path_1):
                train_lora(img_1, prompt_1, save_lora_dir, None, self.tokenizer, self.text_encoder,
                           self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
        print(f"Load from {load_lora_path_1}.")
        lora_1 = read_lora_weights(load_lora_path_1)

    text_embeddings_0 = self.get_text_embeddings(
        prompt_0, guidance_scale, neg_prompt, batch_size)
    text_embeddings_1 = self.get_text_embeddings(
        prompt_1, guidance_scale, neg_prompt, batch_size)
    img_0 = get_img(img_0)
    img_1 = get_img(img_1)

    # Invert each image with its own LoRA loaded (weight 0 -> first, 1 -> second).
    if self.use_lora:
        self.unet = load_lora(self.unet, lora_0, lora_1, 0)
    img_noise_0 = self.ddim_inversion(
        self.image2latent(img_0), text_embeddings_0)
    if self.use_lora:
        self.unet = load_lora(self.unet, lora_0, lora_1, 1)
    img_noise_1 = self.ddim_inversion(
        self.image2latent(img_1), text_embeddings_1)

    print("latents shape: ", img_noise_0.shape)

    original_processor = list(self.unet.attn_processors.values())[0]

    def morph(alpha_list, progress, desc):
        def sample_frame(frame_alpha, lora_flag):
            # Denoise one frame and return it as a PIL image.
            latents = self.cal_latent(
                num_inference_steps,
                guidance_scale,
                unconditioning,
                img_noise_0,
                img_noise_1,
                text_embeddings_0,
                text_embeddings_1,
                lora_0,
                lora_1,
                frame_alpha,
                lora_flag,
                fix_lora
            )
            return Image.fromarray(self.latent2image(latents))

        def reload_lora(weight):
            # No-op without LoRA; fix_lora overrides the per-frame weight.
            if self.use_lora:
                self.unet = load_lora(
                    self.unet, lora_0, lora_1, weight if fix_lora is None else fix_lora)

        def install_processors(make_wrapper):
            # Wrap every processor selected by do_replace_attn. Without LoRA
            # the wrappers always sit on original_processor, so wrappers from
            # an earlier frame are not nested.
            current = self.unet.attn_processors
            attn_processor_dict = {}
            for key, proc in current.items():
                if do_replace_attn(key):
                    base = proc if self.use_lora else original_processor
                    attn_processor_dict[key] = make_wrapper(base, key)
                else:
                    attn_processor_dict[key] = proc
            self.unet.set_attn_processor(attn_processor_dict)

        images = []

        if attn_beta is None:
            # Plain interpolation: no attention processors are replaced.
            for k, alpha in enumerate(alpha_list):
                image = sample_frame(alpha_list[k], self.use_lora)
                if save_intermediates:
                    image.save(f"{self.output_path}/{k:02d}.png")
                images.append(image)
            return images

        # First frame: record its self-attention inputs into img0_dict.
        reload_lora(0)
        install_processors(
            lambda base, key: StoreProcessor(base, self.img0_dict, key))
        first_image = sample_frame(alpha_list[0], False)
        if save_intermediates:
            first_image.save(f"{self.output_path}/{0:02d}.png")

        # Last frame: record its self-attention inputs into img1_dict.
        reload_lora(1)
        install_processors(
            lambda base, key: StoreProcessor(base, self.img1_dict, key))
        last_image = sample_frame(alpha_list[-1], False)
        if save_intermediates:
            last_image.save(
                f"{self.output_path}/{num_frames - 1:02d}.png")

        # Intermediate frames: blend the two recorded sets of maps.
        for i in progress.tqdm(range(1, num_frames - 1), desc=desc):
            alpha = alpha_list[i]
            reload_lora(alpha)
            install_processors(
                lambda base, key: LoadProcessor(
                    base, key, self.img0_dict, self.img1_dict, alpha, attn_beta, lamd))
            image = sample_frame(alpha_list[i], False)
            if save_intermediates:
                image.save(f"{self.output_path}/{i:02d}.png")
            images.append(image)

        return [first_image] + images + [last_image]

    with torch.no_grad():
        if not self.use_reschedule:
            alpha_list = list(torch.linspace(0, 1, num_frames))
            print(alpha_list)
            return morph(alpha_list, progress, "Sampling...")

        # Two passes: sample on a uniform alpha grid, let AlphaScheduler
        # derive a new alpha list from those frames, then sample again.
        alpha_scheduler = AlphaScheduler()
        alpha_list = list(torch.linspace(0, 1, num_frames))
        images_pt = morph(alpha_list, progress, "Sampling...")
        to_tensor = transforms.ToTensor()
        images_pt = [to_tensor(img).unsqueeze(0) for img in images_pt]
        alpha_scheduler.from_imgs(images_pt)
        alpha_list = alpha_scheduler.get_list()
        print(alpha_list)
        images = morph(alpha_list, progress, "Reschedule...")

    return images
|
multi_image/README.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Update
|
2 |
+
|
3 |
+
Added support for multi-image input: you can now generate a morphing sequence across more than two images.
|
4 |
+
|
5 |
+
## Run the code
|
6 |
+
|
7 |
+
You can run the code with the following command:
|
8 |
+
|
9 |
+
```
|
10 |
+
python main.py \
|
11 |
+
--image_paths [image_path_0] ... [image_path_n] \
|
12 |
+
--prompts [prompt_0] ... [prompt_n] \
|
13 |
+
--output_path [output_path] \
|
14 |
+
--use_adain --use_reschedule --save_inter
|
15 |
+
```
|
16 |
+
|
17 |
+
This modification adds support for the following options:
|
18 |
+
|
19 |
+
- `--image_paths`: Paths of the input images
|
20 |
+
- `--prompts`: Prompts of the images
|
21 |
+
- `--load_lora_paths`: Paths of the lora directory of the images
|
22 |
+
|
23 |
+
## Example
|
24 |
+
|
25 |
+
Run the code:
|
26 |
+
```
|
27 |
+
python main.py \
|
28 |
+
--image_paths ./assets/realdog0.jpg ./assets/realdog1.jpg ./assets/realdog2.jpg \
|
29 |
+
--prompts "A photo of a dog" "A photo of a dog" "A photo of a dog" \
|
30 |
+
--output_path "./results/dog" \
|
31 |
+
--use_adain --use_reschedule --save_inter
|
32 |
+
```
|
33 |
+
|
34 |
+
Output:
|
35 |
+
<div align="center">
|
36 |
+
<img src="assets/realdog.gif" width="50%" height="50%">
|
37 |
+
</div>
|
multi_image/assets/realdog.gif
ADDED
Git LFS Details
|
multi_image/assets/realdog0.jpg
ADDED
multi_image/assets/realdog1.jpg
ADDED
multi_image/assets/realdog2.jpg
ADDED
multi_image/main.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
from PIL import Image
|
6 |
+
from argparse import ArgumentParser
|
7 |
+
from model import DiffMorpherPipeline
|
8 |
+
|
9 |
+
# Command-line entry point: parse the options, build the pipeline, run the
# morphing and write the frames as an animated GIF.
parser = ArgumentParser()
parser.add_argument(
    "--model_path", type=str, default="stabilityai/stable-diffusion-2-1-base",
    help="Pretrained model to use (default: %(default)s)"
)
# Two-image options.
parser.add_argument(
    "--image_path_0", type=str, default="",
    help="Path of the first image (default: %(default)s)")
parser.add_argument(
    "--prompt_0", type=str, default="",
    help="Prompt of the first image (default: %(default)s)")
parser.add_argument(
    "--image_path_1", type=str, default="",
    help="Path of the second image (default: %(default)s)")
parser.add_argument(
    "--prompt_1", type=str, default="",
    help="Prompt of the second image (default: %(default)s)")
parser.add_argument(
    "--load_lora_path_0", type=str, default="",
    help="Path of the lora directory of the first image (default: %(default)s)"
)
parser.add_argument(
    "--load_lora_path_1", type=str, default="",
    help="Path of the lora directory of the second image (default: %(default)s)"
)
# Multi-image options (any number of values each).
parser.add_argument(
    "--image_paths", type=str, nargs='*', default=[],
    help="Paths of the input images (default: %(default)s)")
parser.add_argument(
    "--prompts", type=str, nargs='*', default=[],
    help="Prompts of the images (default: %(default)s)")
parser.add_argument(
    "--output_path", type=str, default="./results",
    help="Path of the output image (default: %(default)s)"
)
parser.add_argument(
    "--save_lora_dir", type=str, default="./lora",
    help="Path of the output lora directory (default: %(default)s)"
)
parser.add_argument(
    "--load_lora_paths", type=str, nargs='*', default=[],
    help="Paths of the lora directory of the images (default: %(default)s)"
)
parser.add_argument(
    "--use_adain", action="store_true",
    help="Use AdaIN (default: %(default)s)"
)
parser.add_argument(
    "--use_reschedule", action="store_true",
    help="Use reschedule sampling (default: %(default)s)"
)
parser.add_argument(
    "--lamb", type=float, default=0.6,
    help="Lambda for self-attention replacement (default: %(default)s)"
)
parser.add_argument(
    "--fix_lora_value", type=float, default=None,
    help="Fix lora value (default: LoRA Interp., not fixed)"
)
parser.add_argument(
    "--save_inter", action="store_true",
    help="Save intermediate results (default: %(default)s)"
)
parser.add_argument(
    "--num_frames", type=int, default=16,
    help="Number of frames to generate (default: %(default)s)"
)
parser.add_argument(
    "--duration", type=int, default=100,
    help="Duration of each frame (default: %(default)s ms)"
)

args = parser.parse_args()

os.makedirs(args.output_path, exist_ok=True)
pipeline = DiffMorpherPipeline.from_pretrained(
    args.model_path, torch_dtype=torch.float32)
pipeline.to("cuda")
images = pipeline(
    img_path_0=args.image_path_0,
    img_path_1=args.image_path_1,
    prompt_0=args.prompt_0,
    prompt_1=args.prompt_1,
    load_lora_path_0=args.load_lora_path_0,
    load_lora_path_1=args.load_lora_path_1,
    img_paths=args.image_paths,
    prompts=args.prompts,
    save_lora_dir=args.save_lora_dir,
    load_lora_paths=args.load_lora_paths,
    use_adain=args.use_adain,
    use_reschedule=args.use_reschedule,
    lamb=args.lamb,
    output_path=args.output_path,
    num_frames=args.num_frames,
    fix_lora=args.fix_lora_value,
    save_intermediates=args.save_inter,
)
# The first frame carries the animation; the rest are appended to it.
images[0].save(f"{args.output_path}/output.gif", save_all=True,
               append_images=images[1:], duration=args.duration, loop=0)
|
multi_image/model.py
ADDED
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
3 |
+
from diffusers.models.attention_processor import AttnProcessor
|
4 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
5 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import safetensors
|
11 |
+
from PIL import Image
|
12 |
+
from torchvision import transforms
|
13 |
+
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
|
14 |
+
from diffusers import StableDiffusionPipeline
|
15 |
+
from argparse import ArgumentParser
|
16 |
+
|
17 |
+
|
18 |
+
from utils.model_utils import get_img, slerp, do_replace_attn
|
19 |
+
from utils.lora_utils import train_lora, load_lora
|
20 |
+
from utils.alpha_scheduler import AlphaScheduler
|
21 |
+
|
22 |
+
class StoreProcessor():
    """Attention-processor wrapper that records self-attention inputs.

    Every call made without ``encoder_hidden_states`` (i.e. self-attention)
    stores a detached copy of ``hidden_states`` in ``value_dict[name]`` under
    a running integer index. The call is then always forwarded, unchanged, to
    the wrapped processor.
    """

    def __init__(self, original_processor, value_dict, name):
        """Wrap ``original_processor`` and reset the record for ``name``.

        Args:
            original_processor: Callable attention processor to delegate to.
            value_dict: Shared dict; ``value_dict[name]`` is replaced with a
                fresh empty dict and then filled as calls arrive.
            name: Key identifying this attention layer inside ``value_dict``.
        """
        self.original_processor = original_processor
        self.value_dict = value_dict
        self.name = name
        # Index of the next self-attention call to be recorded.
        self.id = 0
        # Any record left over from a previous run is discarded.
        self.value_dict[self.name] = dict()

    def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
        """Record the input when this is self-attention, then delegate.

        Returns:
            Whatever the wrapped processor returns for the same arguments.
        """
        is_self_attention = encoder_hidden_states is None
        if is_self_attention:
            recorded = self.value_dict[self.name]
            recorded[self.id] = hidden_states.detach()
            self.id += 1

        return self.original_processor(
            attn,
            hidden_states,
            *args,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )
|
41 |
+
|
42 |
+
|
43 |
+
class LoadProcessor():
    """Attention-processor wrapper that injects interpolated attention inputs.

    During the first ``lamd`` fraction of the sampling steps, self-attention
    keys/values are computed from a blend of the hidden states previously
    recorded for the two endpoint images (``img0_dict`` / ``img1_dict``)
    instead of from the current hidden states. All other calls are forwarded
    to the wrapped processor.
    """

    def __init__(self, original_processor, name, img0_dict, img1_dict, alpha, beta=0, lamd=0.6):
        """Store the wrapped processor and the blending configuration.

        Args:
            original_processor: Processor to delegate to. ``parent_call`` also
                reads its ``to_q_lora`` / ``to_k_lora`` / ``to_v_lora`` /
                ``to_out_lora`` attributes, so it must provide them.
            name: Key of this attention layer inside the two record dicts.
            img0_dict: ``{name: {step: hidden_states}}`` recorded for image 0.
            img1_dict: Same structure, recorded for image 1.
            alpha: Interpolation weight between the two records (0 -> image 0).
            beta: Weight of the current hidden states in the blend.
            lamd: Fraction of the sampling steps on which to inject the blend.
        """
        self.original_processor = original_processor
        self.name = name
        self.img0_dict = img0_dict
        self.img1_dict = img1_dict
        self.alpha = alpha
        self.beta = beta
        self.lamd = lamd
        # Index of the current self-attention call within one sampling run.
        self.id = 0

    def parent_call(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        """Run attention with LoRA-augmented projections.

        Queries come from ``hidden_states``; keys and values come from
        ``encoder_hidden_states`` (falling back to ``hidden_states`` when it is
        ``None``). Each projection adds ``scale`` times the matching LoRA layer
        of the wrapped processor.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial dimensions into a sequence dimension.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * \
            self.original_processor.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * \
            self.original_processor.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * \
            self.original_processor.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(
            query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](
            hidden_states) + scale * self.original_processor.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
        """Inject the blended record on early self-attention steps, else delegate.

        Returns:
            The attention output, from ``parent_call`` (blend injected) or from
            the wrapped processor (everything else).
        """
        # Cross-attention is never modified.
        if encoder_hidden_states is not None:
            return self.original_processor(attn, hidden_states, *args,
                                           encoder_hidden_states=encoder_hidden_states,
                                           attention_mask=attention_mask,
                                           **kwargs)

        # One entry was recorded per sampling step, so the record length is the
        # step count. This replaces a hard-coded 50 and is identical for the
        # default 50-step schedule.
        total_steps = len(self.img0_dict[self.name])

        if self.id < total_steps * self.lamd:
            map0 = self.img0_dict[self.name][self.id]
            map1 = self.img1_dict[self.name][self.id]
            cross_map = self.beta * hidden_states + \
                (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
            res = self.parent_call(attn, hidden_states, *args,
                                   encoder_hidden_states=cross_map,
                                   attention_mask=attention_mask,
                                   **kwargs)
        else:
            res = self.original_processor(attn, hidden_states, *args,
                                          encoder_hidden_states=encoder_hidden_states,
                                          attention_mask=attention_mask,
                                          **kwargs)

        # Wrap around so the same processor can serve another sampling run.
        self.id += 1
        if self.id == len(self.img0_dict[self.name]):
            self.id = 0

        return res
|
162 |
+
|
163 |
+
|
164 |
+
class DiffMorpherPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline that morphs through a sequence of images.

    For every consecutive pair of input images the pipeline inverts both
    images to noise, then samples ``num_frames`` frames while interpolating
    the noise, the text embeddings, the LoRA weights and (optionally) the
    self-attention inputs between the two endpoints.
    """

    def __init__(self,
                 vae: AutoencoderKL,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 unet: UNet2DConditionModel,
                 scheduler: KarrasDiffusionSchedulers,
                 safety_checker: StableDiffusionSafetyChecker,
                 feature_extractor: CLIPImageProcessor,
                 requires_safety_checker: bool = True,
                 ):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         safety_checker, feature_extractor, requires_safety_checker)
        # Per-layer self-attention inputs recorded for the two endpoint images.
        self.img0_dict = dict()
        self.img1_dict = dict()

    def inv_step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta=0.,
        verbose=False
    ):
        """Perform one inverse DDIM step (towards more noise).

        Args:
            model_output: Noise predicted by the UNet for ``x``.
            timestep: Timestep being inverted *to*.
            x: Current latent.
            eta: Unused; kept for interface compatibility.
            verbose: Print the timestep when true.

        Returns:
            ``(x_next, pred_x0)``: the noisier latent and the predicted clean latent.
        """
        if verbose:
            print("timestep: ", timestep)
        next_step = timestep
        timestep = min(timestep - self.scheduler.config.num_train_timesteps //
                       self.scheduler.num_inference_steps, 999)
        alpha_prod_t = self.scheduler.alphas_cumprod[
            timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
        return x_next, pred_x0

    @torch.no_grad()
    def invert(
        self,
        image: torch.Tensor,
        prompt,
        num_inference_steps=50,
        num_actual_inference_steps=None,
        guidance_scale=1.,
        eta=0.0,
        **kwds):
        """Invert a real image into a noise map with deterministic DDIM inversion.

        Args:
            image: Image batch handed to ``image2latent``.
            prompt: A string or a list of strings conditioning the inversion.
            num_inference_steps: Number of scheduler steps.
            num_actual_inference_steps: If set, stop after this many steps.
            guidance_scale: Classifier-free guidance is used when > 1.
            eta: Unused; kept for interface compatibility.

        Returns:
            The inverted latent.
        """
        DEVICE = torch.device(
            "cuda") if torch.cuda.is_available() else torch.device("cpu")
        batch_size = image.shape[0]
        if isinstance(prompt, list):
            if batch_size == 1:
                image = image.expand(len(prompt), -1, -1, -1)
        elif isinstance(prompt, str):
            if batch_size > 1:
                prompt = [prompt] * batch_size

        # text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            return_tensors="pt"
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
        print("input text embeddings :", text_embeddings.shape)
        # define initial latents
        latents = self.image2latent(image)

        # unconditional embedding for classifier free guidance
        if guidance_scale > 1.:
            # Size the unconditional batch from the conditional embeddings so
            # it still matches after a single image was expanded to a prompt list.
            unconditional_input = self.tokenizer(
                [""] * text_embeddings.shape[0],
                padding="max_length",
                max_length=77,
                return_tensors="pt"
            )
            unconditional_embeddings = self.text_encoder(
                unconditional_input.input_ids.to(DEVICE))[0]
            text_embeddings = torch.cat(
                [unconditional_embeddings, text_embeddings], dim=0)

        print("latents shape: ", latents.shape)
        # iterative sampling
        self.scheduler.set_timesteps(num_inference_steps)
        print("Valid timesteps: ", reversed(self.scheduler.timesteps))
        for i, t in enumerate(tqdm.tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
            if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
                break

            if guidance_scale > 1.:
                model_inputs = torch.cat([latents] * 2)
            else:
                model_inputs = latents

            # predict the noise
            noise_pred = self.unet(
                model_inputs, t, encoder_hidden_states=text_embeddings).sample
            if guidance_scale > 1.:
                noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncon + guidance_scale * \
                    (noise_pred_con - noise_pred_uncon)
            # compute the next noise sample x_t-1 -> x_t
            latents, _ = self.inv_step(noise_pred, t, latents)

        return latents

    @torch.no_grad()
    def ddim_inversion(self, latent, cond):
        """Invert ``latent`` to noise over the scheduler's current timesteps.

        Args:
            latent: Clean latent to invert.
            cond: Text embedding; repeated along dim 0 to the latent batch size.

        Returns:
            The inverted (noisy) latent.
        """
        timesteps = reversed(self.scheduler.timesteps)
        # The latent batch size does not change inside the loop.
        cond_batch = cond.repeat(latent.shape[0], 1, 1)
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            for i, t in enumerate(tqdm.tqdm(timesteps, desc="DDIM inversion")):
                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[timesteps[i - 1]]
                    if i > 0 else self.scheduler.final_alpha_cumprod
                )

                mu = alpha_prod_t ** 0.5
                mu_prev = alpha_prod_t_prev ** 0.5
                sigma = (1 - alpha_prod_t) ** 0.5
                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

                eps = self.unet(
                    latent, t, encoder_hidden_states=cond_batch).sample

                pred_x0 = (latent - sigma_prev * eps) / mu_prev
                latent = mu * pred_x0 + sigma * eps
        return latent

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
    ):
        """Predict the sample of the next step in the denoising process.

        Returns:
            ``(x_prev, pred_x0)``: the less noisy latent and the predicted clean latent.
        """
        prev_timestep = timestep - \
            self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[
            prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
        x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
        return x_prev, pred_x0

    @torch.no_grad()
    def image2latent(self, image):
        """Encode an image into scaled VAE latents.

        Args:
            image: Either a PIL image (converted to a ``[-1, 1]`` tensor with a
                leading batch dimension) or a tensor passed to the VAE as-is.

        Returns:
            The mean of the VAE latent distribution multiplied by 0.18215.
        """
        DEVICE = torch.device(
            "cuda") if torch.cuda.is_available() else torch.device("cpu")
        # `Image` is the PIL module, so the check must be against `Image.Image`;
        # `type(image) is Image` could never be true.
        if isinstance(image, Image.Image):
            image = np.array(image)
            image = torch.from_numpy(image).float() / 127.5 - 1
            image = image.permute(2, 0, 1).unsqueeze(0)
        # input image density range [-1, 1]
        latents = self.vae.encode(image.to(DEVICE))['latent_dist'].mean
        latents = latents * 0.18215
        return latents

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        """Decode latents with the VAE.

        Args:
            latents: Scaled latents as produced by ``image2latent``.
            return_type: ``'np'`` returns the first image as an HxWx3 uint8
                array; ``'pt'`` returns the batch as a tensor clamped to
                ``[0, 1]``; any other value returns the raw decoder output.
        """
        latents = 1 / 0.18215 * latents.detach()
        image = self.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        elif return_type == "pt":
            image = (image / 2 + 0.5).clamp(0, 1)

        return image

    def latent2image_grad(self, latents):
        """Decode latents without disabling gradients; output range is [-1, 1]."""
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)['sample']

        return image  # range [-1, 1]

    @torch.no_grad()
    def cal_latent(self, num_inference_steps, guidance_scale, unconditioning, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1, alpha, use_lora, fix_lora=None):
        """Sample the latent of one intermediate frame.

        The start noise is ``slerp(img_noise_0, img_noise_1, alpha)`` and the
        text embedding is the linear blend of the two endpoint embeddings.

        Args:
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance is used when > 1.
            unconditioning: Optional list of per-step unconditional embeddings.
            img_noise_0, img_noise_1: Inverted noise of the two endpoints.
            text_embeddings_0, text_embeddings_1: Endpoint text embeddings.
            lora_0, lora_1: Endpoint LoRA state dicts (used only if ``use_lora``).
            alpha: Interpolation position between the endpoints.
            use_lora: Load interpolated LoRA weights into the UNet first.
            fix_lora: If not ``None``, LoRA mixing weight used instead of ``alpha``.

        Returns:
            The denoised latent.
        """
        latents = slerp(img_noise_0, img_noise_1, alpha, self.use_adain)
        text_embeddings = (1 - alpha) * text_embeddings_0 + \
            alpha * text_embeddings_1

        self.scheduler.set_timesteps(num_inference_steps)
        if use_lora:
            self.unet = load_lora(
                self.unet, lora_0, lora_1, alpha if fix_lora is None else fix_lora)

        for i, t in enumerate(tqdm.tqdm(self.scheduler.timesteps, desc=f"DDIM Sampler, alpha={alpha}")):

            if guidance_scale > 1.:
                model_inputs = torch.cat([latents] * 2)
            else:
                model_inputs = latents
            if unconditioning is not None and isinstance(unconditioning, list):
                _, text_embeddings = text_embeddings.chunk(2)
                text_embeddings = torch.cat(
                    [unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
            # predict the noise
            noise_pred = self.unet(
                model_inputs, t, encoder_hidden_states=text_embeddings).sample
            if guidance_scale > 1.0:
                noise_pred_uncon, noise_pred_con = noise_pred.chunk(
                    2, dim=0)
                noise_pred = noise_pred_uncon + guidance_scale * \
                    (noise_pred_con - noise_pred_uncon)
            # compute the previous noise sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred, t, latents, return_dict=False)[0]
        return latents

    @torch.no_grad()
    def get_text_embeddings(self, prompt, guidance_scale, neg_prompt, batch_size):
        """Encode ``prompt`` (and, for guidance > 1, the negative prompt).

        Args:
            prompt: Prompt string (or list of strings) to encode.
            guidance_scale: When > 1, unconditional embeddings are prepended.
            neg_prompt: Text for the unconditional branch; ``""`` if falsy.
            batch_size: Number of unconditional prompts to encode.

        Returns:
            The text embeddings, with the unconditional embeddings concatenated
            in front along dim 0 when ``guidance_scale > 1``.
        """
        DEVICE = torch.device(
            "cuda") if torch.cuda.is_available() else torch.device("cpu")
        # text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            return_tensors="pt"
        )
        # Use DEVICE (not a hard-coded .cuda()) so CPU-only hosts work too.
        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]

        if guidance_scale > 1.:
            uc_text = neg_prompt if neg_prompt else ""
            unconditional_input = self.tokenizer(
                [uc_text] * batch_size,
                padding="max_length",
                max_length=77,
                return_tensors="pt"
            )
            unconditional_embeddings = self.text_encoder(
                unconditional_input.input_ids.to(DEVICE))[0]
            text_embeddings = torch.cat(
                [unconditional_embeddings, text_embeddings], dim=0)

        return text_embeddings

    def _wrap_attn_processors(self, wrap):
        """Replace the UNet's attention processors selected by ``do_replace_attn``.

        ``wrap(processor, name)`` must return the replacement for a selected
        layer; all other layers keep their current processor.
        """
        current = self.unet.attn_processors
        self.unet.set_attn_processor({
            name: wrap(proc, name) if do_replace_attn(name) else proc
            for name, proc in current.items()
        })

    def __call__(
            self,
            img_0=None,
            img_1=None,
            img_path_0=None,
            img_path_1=None,
            prompt_0="",
            prompt_1="",
            imgs=None,
            img_paths=None,
            prompts=None,
            save_lora_dir="./lora",
            load_lora_path_0=None,
            load_lora_path_1=None,
            load_lora_paths=None,
            lora_steps=200,
            lora_lr=2e-4,
            lora_rank=16,
            batch_size=1,
            height=512,
            width=512,
            num_inference_steps=50,
            num_actual_inference_steps=None,
            guidance_scale=1,
            attn_beta=0,
            lamd=0.6,
            use_lora=True,
            use_adain=True,
            use_reschedule=True,
            output_path="./results",
            num_frames=50,
            fix_lora=None,
            progress=tqdm,
            unconditioning=None,
            neg_prompt=None,
            save_intermediates=False,
            lamb=None,
            **kwds):
        """Generate a morphing sequence through the given images.

        Args:
            img_0, img_1: Unused; kept for interface compatibility.
            img_path_0, img_path_1: Used as a two-image sequence when
                ``img_paths`` is empty or ``None``.
            prompt_0, prompt_1: Prompts for the two-image case when ``prompts``
                is ``None``.
            imgs: Ignored; images are always loaded from the paths.
            img_paths: Paths of the images to morph through, in order.
            prompts: One prompt per image; defaults to empty prompts.
            save_lora_dir: Directory where newly trained LoRA weights are saved.
            load_lora_path_0, load_lora_path_1: LoRA files for the two-image
                case when ``load_lora_paths`` is empty or ``None``.
            load_lora_paths: LoRA file per image; a missing trailing entry is
                generated (and trained if the file does not exist). The
                caller's list is not modified.
            lora_steps, lora_lr, lora_rank: Forwarded to ``train_lora``.
            batch_size: Number of unconditional prompts for guidance.
            height, width, num_actual_inference_steps: Unused here.
            num_inference_steps: Number of scheduler steps.
            guidance_scale: Classifier-free guidance is used when > 1.
            attn_beta: Weight of the current hidden states in the attention
                blend; ``None`` disables attention interpolation.
            lamd: Fraction of steps on which attention inputs are replaced.
            use_lora: Train/load one LoRA per image and interpolate them.
            use_adain: Forwarded to ``slerp``.
            use_reschedule: Sample once uniformly, then resample with the
                alphas proposed by ``AlphaScheduler``.
            output_path: Where intermediates are saved; its last path component
                also prefixes generated LoRA file names.
            num_frames: Frames generated per consecutive image pair.
            fix_lora: If not ``None``, fixed LoRA mixing weight.
            progress: Object exposing ``tqdm(iterable, desc=...)``.
            unconditioning: Optional per-step unconditional embeddings.
            neg_prompt: Negative prompt for guidance.
            save_intermediates: Save every frame as ``NN.png`` in ``output_path``.
            lamb: Alias for ``lamd`` (the name the CLI passes); wins when given.

        Returns:
            The list of generated PIL images for the whole sequence.

        Raises:
            ValueError: If neither ``img_paths`` nor both ``img_path_0`` and
                ``img_path_1`` are provided.
        """
        # The CLI passes `lamb`; without this alias it was swallowed by **kwds.
        if lamb is not None:
            lamd = lamb

        self.scheduler.set_timesteps(num_inference_steps)
        self.use_lora = use_lora
        self.use_adain = use_adain
        self.use_reschedule = use_reschedule
        self.output_path = output_path

        # Fall back to the two-image arguments when no path list is given.
        pair_mode = not img_paths
        if pair_mode:
            if img_path_0 is None or img_path_1 is None:
                raise ValueError(
                    "Provide `img_paths`, or both `img_path_0` and `img_path_1`.")
            img_paths = [img_path_0, img_path_1]

        imgs = [Image.open(img_path).convert("RGB") for img_path in img_paths]
        if prompts is None:
            prompts = [prompt_0, prompt_1] if pair_mode else [""] * len(imgs)
        assert len(prompts) == len(imgs), "need exactly one prompt per image"

        # Work on a copy so the caller's list is never mutated.
        load_lora_paths = list(load_lora_paths or [])
        if pair_mode and not load_lora_paths:
            for path in (load_lora_path_0, load_lora_path_1):
                if not path:
                    break
                load_lora_paths.append(path)

        if self.use_lora:
            loras = []
            print("Loading lora...")
            for i, (img, prompt) in enumerate(zip(imgs, prompts)):
                if len(load_lora_paths) == i:
                    # No LoRA was supplied for this image: derive a file name
                    # and train it unless it is already on disk.
                    # NOTE(review): the training check is assumed to be nested
                    # in this branch (indentation was lost in the source).
                    weight_name = f"{output_path.split('/')[-1]}_lora_{i}.ckpt"
                    load_lora_paths.append(save_lora_dir + "/" + weight_name)
                    if not os.path.exists(load_lora_paths[i]):
                        train_lora(img, prompt, save_lora_dir, None, self.tokenizer, self.text_encoder,
                                   self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
                print(f"Load from {load_lora_paths[i]}.")
                if load_lora_paths[i].endswith(".safetensors"):
                    loras.append(safetensors.torch.load_file(
                        load_lora_paths[i], device="cpu"))
                else:
                    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
                    loras.append(torch.load(load_lora_paths[i], map_location="cpu"))
        else:
            # Keep the pairwise zip below valid when LoRA is disabled.
            loras = [None] * len(imgs)

        def morph(alpha_list, progress, desc, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1):
            """Render one frame per entry of ``alpha_list`` for a single image pair."""

            def render(alpha, load_lora_in_sampler):
                latents = self.cal_latent(
                    num_inference_steps,
                    guidance_scale,
                    unconditioning,
                    img_noise_0,
                    img_noise_1,
                    text_embeddings_0,
                    text_embeddings_1,
                    lora_0,
                    lora_1,
                    alpha,
                    load_lora_in_sampler,
                    fix_lora
                )
                return Image.fromarray(self.latent2image(latents))

            if attn_beta is None:
                # No attention interpolation: every frame is sampled independently.
                return [render(alpha, self.use_lora) for alpha in alpha_list]

            # First endpoint: record its self-attention inputs.
            self.unet = load_lora(self.unet, lora_0, lora_1, 0 if fix_lora is None else fix_lora)
            self._wrap_attn_processors(
                lambda proc, name: StoreProcessor(proc, self.img0_dict, name))
            first_image = render(alpha_list[0], False)

            # Second endpoint: record its self-attention inputs.
            self.unet = load_lora(self.unet, lora_0, lora_1, 1 if fix_lora is None else fix_lora)
            self._wrap_attn_processors(
                lambda proc, name: StoreProcessor(proc, self.img1_dict, name))
            last_image = render(alpha_list[-1], False)

            # Intermediate frames: inject the blended records.
            middle = []
            for i in progress.tqdm(range(1, num_frames - 1), desc=desc):
                alpha = alpha_list[i]
                self.unet = load_lora(self.unet, lora_0, lora_1, alpha if fix_lora is None else fix_lora)
                self._wrap_attn_processors(
                    lambda proc, name, alpha=alpha: LoadProcessor(
                        proc, name, self.img0_dict, self.img1_dict, alpha, attn_beta, lamd))
                middle.append(render(alpha, False))

            return [first_image] + middle + [last_image]

        images = []

        for img_0, img_1, prompt_0, prompt_1, lora_0, lora_1 in zip(imgs[:-1], imgs[1:], prompts[:-1], prompts[1:], loras[:-1], loras[1:]):
            text_embeddings_0 = self.get_text_embeddings(
                prompt_0, guidance_scale, neg_prompt, batch_size)
            text_embeddings_1 = self.get_text_embeddings(
                prompt_1, guidance_scale, neg_prompt, batch_size)
            img_0 = get_img(img_0)
            img_1 = get_img(img_1)
            if self.use_lora:
                self.unet = load_lora(self.unet, lora_0, lora_1, 0)
            img_noise_0 = self.ddim_inversion(
                self.image2latent(img_0), text_embeddings_0)
            if self.use_lora:
                self.unet = load_lora(self.unet, lora_0, lora_1, 1)
            img_noise_1 = self.ddim_inversion(
                self.image2latent(img_1), text_embeddings_1)

            print("latents shape: ", img_noise_0.shape)

            with torch.no_grad():
                if self.use_reschedule:
                    # First pass with uniform alphas, then resample with the
                    # alphas the scheduler derives from those frames.
                    alpha_scheduler = AlphaScheduler()
                    alpha_list = list(torch.linspace(0, 1, num_frames))
                    images_pt = morph(alpha_list, progress, "Sampling...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1)
                    images_pt = [transforms.ToTensor()(img).unsqueeze(0)
                                 for img in images_pt]
                    alpha_scheduler.from_imgs(images_pt)
                    alpha_list = alpha_scheduler.get_list()
                    print(alpha_list)
                    images_ = morph(alpha_list, progress, "Reschedule...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1)
                else:
                    alpha_list = list(torch.linspace(0, 1, num_frames))
                    print(alpha_list)
                    images_ = morph(alpha_list, progress, "Sampling...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1)

            if len(images) == 0:
                images = images_
            else:
                # The first frame of this pair repeats the last frame of the previous one.
                images += images_[1:]

        if save_intermediates:
            os.makedirs(self.output_path, exist_ok=True)
            for i, image in enumerate(images):
                image.save(f"{self.output_path}/{i:02d}.png")

        return images
|
requirements.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.23.0
|
2 |
+
diffusers==0.17.1
|
3 |
+
einops==0.7.0
|
4 |
+
gradio==4.7.1
|
5 |
+
numpy==1.26.1
|
6 |
+
opencv_python==4.5.5.64
|
7 |
+
packaging==23.2
|
8 |
+
Pillow==10.1.0
|
9 |
+
safetensors==0.4.0
|
10 |
+
tqdm==4.65.0
|
11 |
+
transformers==4.34.1
|
12 |
+
torch
|
13 |
+
torchvision
|
14 |
+
lpips
|