diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..693ef229a1b4e9703a19e6c45df3a1cae968e7eb 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,8 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/dog_sit.png filter=lfs diff=lfs merge=lfs -text +assets/dog.png filter=lfs diff=lfs merge=lfs -text +assets/teaser.gif filter=lfs diff=lfs merge=lfs -text +assets/Teaser.png filter=lfs diff=lfs merge=lfs -text +multi_image/assets/realdog.gif filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..31922c938bb364cbb08986bce8d7f81b27abe440 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,10 @@ +S-Lab License 1.0  +  +Copyright 2023 S-Lab +Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met: +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work. +  diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..14f0e670ec48a16bdf82f3e46de14e5d3e690a77 --- /dev/null +++ b/README.md @@ -0,0 +1,131 @@ +

+

DiffMorpher: Unleashing the Capability of Diffusion Models for Image Morphing

+

CVPR 2024

+

+ Kaiwen Zhang +    + Yifan Zhou +    + Xudong Xu +    + Xingang Pan +    + Bo Dai +

+
+ +

+ Corresponding Author +

+ +
+ +
+ +

+ arXiv + page + Twitter + Twitter +

+
+

+ +## Web Demos + +[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/KaiwenZhang/DiffMorpher) + +

+ Huggingface +

+
+
+
+## Requirements
+To install the requirements, run the following in your environment first:
+```bash
+pip install -r requirements.txt
+```
+To run the code with CUDA properly, you can comment out `torch` and `torchvision` in `requirements.txt`, and install the appropriate versions of `torch` and `torchvision` according to the instructions on [PyTorch](https://pytorch.org/get-started/locally/).
+
+You can also download the pretrained model *Stable Diffusion v2.1-base* from [Huggingface](https://huggingface.co./stabilityai/stable-diffusion-2-1-base) and set `model_path` to your local directory.
+
+## Run Gradio UI
+To start the Gradio UI of DiffMorpher, run the following in your environment:
+```bash
+python app.py
+```
+Then, by default, you can access the UI at [http://127.0.0.1:7860](http://127.0.0.1:7860).
+
+## Run the code
+You can also run the code with the following command:
+```bash
+python main.py \
+  --image_path_0 [image_path_0] --image_path_1 [image_path_1] \
+  --prompt_0 [prompt_0] --prompt_1 [prompt_1] \
+  --output_path [output_path] \
+  --use_adain --use_reschedule --save_inter
+```
+The script also supports the following options:
+
+- `--image_path_0`: Path of the first image (default: "")
+- `--prompt_0`: Prompt of the first image (default: "")
+- `--image_path_1`: Path of the second image (default: "")
+- `--prompt_1`: Prompt of the second image (default: "")
+- `--model_path`: Pretrained model path (default: "stabilityai/stable-diffusion-2-1-base")
+- `--output_path`: Path of the output directory (default: "./results")
+- `--save_lora_dir`: Path of the output LoRA directory (default: "./lora")
+- `--load_lora_path_0`: Path of the LoRA directory of the first image (default: "")
+- `--load_lora_path_1`: Path of the LoRA directory of the second image (default: "")
+- `--use_adain`: Use AdaIN (default: False)
+- `--use_reschedule`: Use reschedule sampling (default: False)
+- `--lamb`: Hyperparameter $\lambda \in [0,1]$ for self-attention replacement, where a larger $\lambda$ replaces self-attention in more denoising steps (default: 0.6)
+- `--fix_lora_value`: Fix the LoRA interpolation weight to this value (default: not fixed, i.e. LoRA interpolation)
+- `--save_inter`: Save intermediate results (default: False)
+- `--num_frames`: Number of frames to generate (default: 16)
+- `--duration`: Duration of each frame in milliseconds (default: 100)
+
+Examples:
+```bash
+python main.py \
+  --image_path_0 ./assets/Trump.jpg --image_path_1 ./assets/Biden.jpg \
+  --prompt_0 "A photo of an American man" --prompt_1 "A photo of an American man" \
+  --output_path "./results/Trump_Biden" \
+  --use_adain --use_reschedule --save_inter
+```
+
+```bash
+python main.py \
+  --image_path_0 ./assets/vangogh.jpg --image_path_1 ./assets/pearlgirl.jpg \
+  --prompt_0 "An oil painting of a man" --prompt_1 "An oil painting of a woman" \
+  --output_path "./results/vangogh_pearlgirl" \
+  --use_adain --use_reschedule --save_inter
+```
+
+```bash
+python main.py \
+  --image_path_0 ./assets/lion.png --image_path_1 ./assets/tiger.png \
+  --prompt_0 "A photo of a lion" --prompt_1 "A photo of a tiger" \
+  --output_path "./results/lion_tiger" \
+  --use_adain --use_reschedule --save_inter
+```
+
+## MorphBench
+To evaluate the effectiveness of our methods, we present *MorphBench*, the first benchmark dataset for assessing image morphing of general objects. You can download the dataset from [Google Drive](https://drive.google.com/file/d/1NWPzJhOgP-udP_wYbd0selRG4cu8xsu4/view?usp=sharing) or [Baidu Netdisk](https://pan.baidu.com/s/1J3xE3OJdEhKyoc1QObyYaA?pwd=putk).
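The Requirements section above notes that the base model can be downloaded once and referenced locally via `model_path`. A minimal sketch of that workflow is shown below; it assumes `git` and `git-lfs` are available, and the local directory name is only an example, not something the repository prescribes.

```bash
# Fetch Stable Diffusion v2.1-base once (requires git-lfs), then reuse it offline.
git lfs install
git clone https://huggingface.co./stabilityai/stable-diffusion-2-1-base

# Point --model_path at the local copy instead of the Hugging Face model ID.
python main.py \
    --model_path ./stable-diffusion-2-1-base \
    --image_path_0 ./assets/lion.png --image_path_1 ./assets/tiger.png \
    --prompt_0 "A photo of a lion" --prompt_1 "A photo of a tiger" \
    --output_path "./results/lion_tiger" \
    --use_adain --use_reschedule --save_inter
```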
+ + +## License +The code related to the DiffMorpher algorithm is licensed under [LICENSE](LICENSE.txt). + +However, this project is mostly built on the open-sourse library [diffusers](https://github.com/huggingface/diffusers), which is under a separate license terms [Apache License 2.0](https://github.com/huggingface/diffusers/blob/main/LICENSE). (Cheers to the community as well!) + +## Citation + +```bibtex +@article{zhang2023diffmorpher, + title={DiffMorpher: Unleashing the Capability of Diffusion Models for Image Morphing}, + author={Zhang, Kaiwen and Zhou, Yifan and Xu, Xudong and Pan, Xingang and Dai, Bo}, + journal={arXiv preprint arXiv:2312.07409}, + year={2023} +} +``` diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..1c3c726f88b86d568e27c4d4c641e1ad52de4f16 --- /dev/null +++ b/app.py @@ -0,0 +1,315 @@ +import os +import torch +import numpy as np +import cv2 +import gradio as gr +from PIL import Image +from datetime import datetime +from model import DiffMorpherPipeline +from utils.lora_utils import train_lora + +LENGTH=450 + +def train_lora_interface( + image, + prompt, + model_path, + output_path, + lora_steps, + lora_rank, + lora_lr, + num +): + os.makedirs(output_path, exist_ok=True) + train_lora(image, prompt, output_path, model_path, + lora_steps=lora_steps, lora_lr=lora_lr, lora_rank=lora_rank, weight_name=f"lora_{num}.ckpt", progress=gr.Progress()) + return f"Train LoRA {'A' if num == 0 else 'B'} Done!" + +def run_diffmorpher( + image_0, + image_1, + prompt_0, + prompt_1, + model_path, + lora_mode, + lamb, + use_adain, + use_reschedule, + num_frames, + fps, + save_inter, + load_lora_path_0, + load_lora_path_1, + output_path +): + run_id = datetime.now().strftime("%H%M") + "_" + datetime.now().strftime("%Y%m%d") + os.makedirs(output_path, exist_ok=True) + morpher_pipeline = DiffMorpherPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cuda") + if lora_mode == "Fix LoRA A": + fix_lora = 0 + elif lora_mode == "Fix LoRA B": + fix_lora = 1 + else: + fix_lora = None + if not load_lora_path_0: + load_lora_path_0 = f"{output_path}/lora_0.ckpt" + if not load_lora_path_1: + load_lora_path_1 = f"{output_path}/lora_1.ckpt" + images = morpher_pipeline( + img_0=image_0, + img_1=image_1, + prompt_0=prompt_0, + prompt_1=prompt_1, + load_lora_path_0=load_lora_path_0, + load_lora_path_1=load_lora_path_1, + lamb=lamb, + use_adain=use_adain, + use_reschedule=use_reschedule, + num_frames=num_frames, + fix_lora=fix_lora, + save_intermediates=save_inter, + progress=gr.Progress() + ) + video_path = f"{output_path}/{run_id}.mp4" + video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (512, 512)) + for i, image in enumerate(images): + # image.save(f"{output_path}/{i}.png") + video.write(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)) + video.release() + cv2.destroyAllWindows() + return gr.Video(value=video_path, format="mp4", label="Output video", show_label=True, height=LENGTH, width=LENGTH, interactive=False) + +def run_all( + image_0, + image_1, + prompt_0, + prompt_1, + model_path, + lora_mode, + lamb, + use_adain, + use_reschedule, + num_frames, + fps, + save_inter, + load_lora_path_0, + load_lora_path_1, + output_path, + lora_steps, + lora_rank, + lora_lr +): + os.makedirs(output_path, exist_ok=True) + train_lora(image_0, prompt_0, output_path, model_path, + lora_steps=lora_steps, lora_lr=lora_lr, lora_rank=lora_rank, weight_name=f"lora_0.ckpt", progress=gr.Progress()) + train_lora(image_1, prompt_1, 
output_path, model_path, + lora_steps=lora_steps, lora_lr=lora_lr, lora_rank=lora_rank, weight_name=f"lora_1.ckpt", progress=gr.Progress()) + return run_diffmorpher( + image_0, + image_1, + prompt_0, + prompt_1, + model_path, + lora_mode, + lamb, + use_adain, + use_reschedule, + num_frames, + fps, + save_inter, + load_lora_path_0, + load_lora_path_1, + output_path + ) + +with gr.Blocks() as demo: + + with gr.Row(): + gr.Markdown(""" + # Official Implementation of [DiffMorpher](https://kevin-thu.github.io/DiffMorpher_page/) + """) + + original_image_0, original_image_1 = gr.State(Image.open("assets/Trump.jpg").convert("RGB").resize((512,512), Image.BILINEAR)), gr.State(Image.open("assets/Biden.jpg").convert("RGB").resize((512,512), Image.BILINEAR)) + # key_points_0, key_points_1 = gr.State([]), gr.State([]) + # to_change_points = gr.State([]) + + with gr.Row(): + with gr.Column(): + input_img_0 = gr.Image(type="numpy", label="Input image A", value="assets/Trump.jpg", show_label=True, height=LENGTH, width=LENGTH, interactive=True) + prompt_0 = gr.Textbox(label="Prompt for image A", value="a photo of an American man", interactive=True) + with gr.Row(): + train_lora_0_button = gr.Button("Train LoRA A") + train_lora_1_button = gr.Button("Train LoRA B") + # show_correspond_button = gr.Button("Show correspondence points") + with gr.Column(): + input_img_1 = gr.Image(type="numpy", label="Input image B ", value="assets/Biden.jpg", show_label=True, height=LENGTH, width=LENGTH, interactive=True) + prompt_1 = gr.Textbox(label="Prompt for image B", value="a photo of an American man", interactive=True) + with gr.Row(): + clear_button = gr.Button("Clear All") + run_button = gr.Button("Run w/o LoRA training") + with gr.Column(): + output_video = gr.Video(format="mp4", label="Output video", show_label=True, height=LENGTH, width=LENGTH, interactive=False) + lora_progress_bar = gr.Textbox(label="Display LoRA training progress", interactive=False) + run_all_button = gr.Button("Run!") + # with gr.Column(): + # output_video = gr.Video(label="Output video", show_label=True, height=LENGTH, width=LENGTH) + + with gr.Row(): + gr.Markdown(""" + ### Usage: + 1. Upload two images (with correspondence) and fill out the prompts. + (It's recommended to change `[Output path]` accordingly.) + 2. Click **"Run!"** + + Or: + 1. Upload two images (with correspondence) and fill out the prompts. + 2. Click the **"Train LoRA A/B"** button to fit two LoRAs for two images respectively.
   + If you have trained LoRA A or LoRA B before, you can skip this step and fill in the corresponding LoRA path under the LoRA Settings tab.
   + Trained LoRAs are saved to `[Output Path]/lora_0.ckpt` and `[Output Path]/lora_1.ckpt` by default. + 3. You might also change the settings below. + 4. Click **"Run w/o LoRA training"** + + ### Note: + 1. To speed up the generation process, you can **ruduce the number of frames** or **turn off "Use Reschedule"**. + 2. You can try the influence of different prompts. It seems that using the same prompts or aligned prompts works better. + ### Have fun! + """) + + with gr.Accordion(label="Algorithm Parameters"): + with gr.Tab("Basic Settings"): + with gr.Row(): + # local_models_dir = 'local_pretrained_models' + # local_models_choice = \ + # [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))] + model_path = gr.Text(value="stabilityai/stable-diffusion-2-1-base", + label="Diffusion Model Path", interactive=True + ) + lamb = gr.Slider(value=0.6, minimum=0, maximum=1, step=0.1, label="Lambda for attention replacement", interactive=True) + lora_mode = gr.Dropdown(value="LoRA Interp", + label="LoRA Interp. or Fix LoRA", + choices=["LoRA Interp", "Fix LoRA A", "Fix LoRA B"], + interactive=True + ) + use_adain = gr.Checkbox(value=True, label="Use AdaIN", interactive=True) + use_reschedule = gr.Checkbox(value=True, label="Use Reschedule", interactive=True) + with gr.Row(): + num_frames = gr.Number(value=16, minimum=0, label="Number of Frames", precision=0, interactive=True) + fps = gr.Number(value=8, minimum=0, label="FPS (Frame rate)", precision=0, interactive=True) + save_inter = gr.Checkbox(value=False, label="Save Intermediate Images", interactive=True) + output_path = gr.Text(value="./results", label="Output Path", interactive=True) + + with gr.Tab("LoRA Settings"): + with gr.Row(): + lora_steps = gr.Number(value=200, label="LoRA training steps", precision=0, interactive=True) + lora_lr = gr.Number(value=0.0002, label="LoRA learning rate", interactive=True) + lora_rank = gr.Number(value=16, label="LoRA rank", precision=0, interactive=True) + # save_lora_dir = gr.Text(value="./lora", label="LoRA model save path", interactive=True) + load_lora_path_0 = gr.Text(value="", label="LoRA model load path for image A", interactive=True) + load_lora_path_1 = gr.Text(value="", label="LoRA model load path for image B", interactive=True) + + def store_img(img): + image = Image.fromarray(img).convert("RGB").resize((512,512), Image.BILINEAR) + # resize the input to 512x512 + # image = image.resize((512,512), Image.BILINEAR) + # image = np.array(image) + # when new image is uploaded, `selected_points` should be empty + return image + input_img_0.upload( + store_img, + [input_img_0], + [original_image_0] + ) + input_img_1.upload( + store_img, + [input_img_1], + [original_image_1] + ) + + def clear(LENGTH): + return gr.Image.update(value=None, width=LENGTH, height=LENGTH), \ + gr.Image.update(value=None, width=LENGTH, height=LENGTH), \ + None, None, None, None + clear_button.click( + clear, + [gr.Number(value=LENGTH, visible=False, precision=0)], + [input_img_0, input_img_1, original_image_0, original_image_1, prompt_0, prompt_1] + ) + + train_lora_0_button.click( + train_lora_interface, + [ + original_image_0, + prompt_0, + model_path, + output_path, + lora_steps, + lora_rank, + lora_lr, + gr.Number(value=0, visible=False, precision=0) + ], + [lora_progress_bar] + ) + + train_lora_1_button.click( + train_lora_interface, + [ + original_image_1, + prompt_1, + model_path, + output_path, + lora_steps, + lora_rank, + lora_lr, + gr.Number(value=1, 
visible=False, precision=0) + ], + [lora_progress_bar] + ) + + run_button.click( + run_diffmorpher, + [ + original_image_0, + original_image_1, + prompt_0, + prompt_1, + model_path, + lora_mode, + lamb, + use_adain, + use_reschedule, + num_frames, + fps, + save_inter, + load_lora_path_0, + load_lora_path_1, + output_path + ], + [output_video] + ) + + run_all_button.click( + run_all, + [ + original_image_0, + original_image_1, + prompt_0, + prompt_1, + model_path, + lora_mode, + lamb, + use_adain, + use_reschedule, + num_frames, + fps, + save_inter, + load_lora_path_0, + load_lora_path_1, + output_path, + lora_steps, + lora_rank, + lora_lr + ], + [output_video] + ) + +demo.queue().launch(debug=True) diff --git a/assets/Biden.jpg b/assets/Biden.jpg new file mode 100644 index 0000000000000000000000000000000000000000..7ef6cbe14af8f340799f4fad7478f7e14f012e7d Binary files /dev/null and b/assets/Biden.jpg differ diff --git a/assets/Feifei.jpg b/assets/Feifei.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3ede96c0acaff03842d73047a7d1a2f81ec1797e Binary files /dev/null and b/assets/Feifei.jpg differ diff --git a/assets/Musk.jpg b/assets/Musk.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d13b4c682caead50a7b0ab98e4f91fcdd451c423 Binary files /dev/null and b/assets/Musk.jpg differ diff --git a/assets/Teaser.png b/assets/Teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..f81cff45c68c0830b979966a2c17e4a98ce3355b --- /dev/null +++ b/assets/Teaser.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5aadae7c6c1a0a6b36a91fbf3058bf0f699cfba252dd78e6595343ec5f5a5a08 +size 5698568 diff --git a/assets/Trump.jpg b/assets/Trump.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3401ebcd0529c68bd90fb09c6c5b4c8067d5d979 Binary files /dev/null and b/assets/Trump.jpg differ diff --git a/assets/cat.png b/assets/cat.png new file mode 100644 index 0000000000000000000000000000000000000000..ff87dacb35a5f1a89c253b4e33a15221fa0a2fde Binary files /dev/null and b/assets/cat.png differ diff --git a/assets/dog.png b/assets/dog.png new file mode 100644 index 0000000000000000000000000000000000000000..b115694b4e98b4c53cf34809446f47519016286d --- /dev/null +++ b/assets/dog.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:20f07d4f4e6c207426d516ddc3662572436a5a5d59cf85e47ca0d39d3a1cd252 +size 1561560 diff --git a/assets/dog_sit.png b/assets/dog_sit.png new file mode 100644 index 0000000000000000000000000000000000000000..f227e6bf9a30c17cd56798d51b2bcd21a7785e63 --- /dev/null +++ b/assets/dog_sit.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:754331bc083116d027e69e0177b3556a1cf5103c94a14d1fae292eb2a768d5b9 +size 1369335 diff --git a/assets/drag_realgirl0.png b/assets/drag_realgirl0.png new file mode 100644 index 0000000000000000000000000000000000000000..8fd76676718b10143d8651cddbdfeccf9735c354 Binary files /dev/null and b/assets/drag_realgirl0.png differ diff --git a/assets/drag_realgirl1.png b/assets/drag_realgirl1.png new file mode 100644 index 0000000000000000000000000000000000000000..8b091bb90644e300a6f368035c9907a271322583 Binary files /dev/null and b/assets/drag_realgirl1.png differ diff --git a/assets/drag_sculp0.png b/assets/drag_sculp0.png new file mode 100644 index 0000000000000000000000000000000000000000..2dc6a16063884276c70566339520226f743f78de Binary files /dev/null and b/assets/drag_sculp0.png differ diff --git a/assets/drag_sculp1.png 
b/assets/drag_sculp1.png new file mode 100644 index 0000000000000000000000000000000000000000..f2b5c4ae66f68a8bdb335a4a97d31a8b631f972c Binary files /dev/null and b/assets/drag_sculp1.png differ diff --git a/assets/fuji_0.jpg b/assets/fuji_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e4fef521c0c425a7c29ee16183d4c2de5ad08331 Binary files /dev/null and b/assets/fuji_0.jpg differ diff --git a/assets/fuji_1.jpg b/assets/fuji_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..56d78c7c7845f15864ad9a52b0d58b089eaf3b32 Binary files /dev/null and b/assets/fuji_1.jpg differ diff --git a/assets/house0.jpg b/assets/house0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9e673ffa20bbc394e1c4a777b878ed9eebd38300 Binary files /dev/null and b/assets/house0.jpg differ diff --git a/assets/house1.jpg b/assets/house1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..f7790b576b8e3d9c8ecd30933d35652dffc5aa1a Binary files /dev/null and b/assets/house1.jpg differ diff --git a/assets/jeep.jpg b/assets/jeep.jpg new file mode 100644 index 0000000000000000000000000000000000000000..0b771ad7da28d257fb8ad1070fdfbfd7770f0d21 Binary files /dev/null and b/assets/jeep.jpg differ diff --git a/assets/leo_0.jpg b/assets/leo_0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..52b1decf9baf0c8f4e62949d216e169c8ef9514b Binary files /dev/null and b/assets/leo_0.jpg differ diff --git a/assets/leo_1.jpg b/assets/leo_1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..fe638d3970e6db66fc1948bdb6d67fbf861de727 Binary files /dev/null and b/assets/leo_1.jpg differ diff --git a/assets/lion.png b/assets/lion.png new file mode 100644 index 0000000000000000000000000000000000000000..87210057f3b2c053b93b1e23dd07897278b0b4dd Binary files /dev/null and b/assets/lion.png differ diff --git a/assets/man_paint.png b/assets/man_paint.png new file mode 100644 index 0000000000000000000000000000000000000000..c1ff9b3e36959bd607ebab1d8766e38b02381f1c Binary files /dev/null and b/assets/man_paint.png differ diff --git a/assets/mit.jpg b/assets/mit.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d63d72d696b488bdd28c92ffab12715a793122f8 Binary files /dev/null and b/assets/mit.jpg differ diff --git a/assets/monalisa.jpeg b/assets/monalisa.jpeg new file mode 100644 index 0000000000000000000000000000000000000000..204ce5309951c63dfb094e32c87e9aa841d7e361 Binary files /dev/null and b/assets/monalisa.jpeg differ diff --git a/assets/obama.jpg b/assets/obama.jpg new file mode 100644 index 0000000000000000000000000000000000000000..39e0419e6798239a566ecd91719e37f0bab27fef Binary files /dev/null and b/assets/obama.jpg differ diff --git a/assets/pearlgirl.jpg b/assets/pearlgirl.jpg new file mode 100644 index 0000000000000000000000000000000000000000..cda10518150f732514f3a83c747fa15558376406 Binary files /dev/null and b/assets/pearlgirl.jpg differ diff --git a/assets/rabbit.png b/assets/rabbit.png new file mode 100644 index 0000000000000000000000000000000000000000..b38dd97901e485606cd74d058e7534386ae63ca5 Binary files /dev/null and b/assets/rabbit.png differ diff --git a/assets/sculp0.png b/assets/sculp0.png new file mode 100644 index 0000000000000000000000000000000000000000..fe3e9500a74da2e0572b1adbbab31f6f70af8b0b Binary files /dev/null and b/assets/sculp0.png differ diff --git a/assets/sculp1.png b/assets/sculp1.png new file mode 100644 index 
0000000000000000000000000000000000000000..96f36bd465893cbda0209ad31321ba6efbf84f05 Binary files /dev/null and b/assets/sculp1.png differ diff --git a/assets/teaser.gif b/assets/teaser.gif new file mode 100644 index 0000000000000000000000000000000000000000..7d2196dd017d400b5c9949e4a99611b6999b8bb8 --- /dev/null +++ b/assets/teaser.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1588453d2980b0a25c64ba02429493a5b2181e0547b268386dea93b8163a8e51 +size 26695155 diff --git a/assets/thu.jpg b/assets/thu.jpg new file mode 100644 index 0000000000000000000000000000000000000000..835a48689bf6064d711bd2d45eb49de77d0dd826 Binary files /dev/null and b/assets/thu.jpg differ diff --git a/assets/tiger.png b/assets/tiger.png new file mode 100644 index 0000000000000000000000000000000000000000..45c2ef0515c2ef38bb605e33254d321121039d18 Binary files /dev/null and b/assets/tiger.png differ diff --git a/assets/van.jpg b/assets/van.jpg new file mode 100644 index 0000000000000000000000000000000000000000..9c23044f471b5229c447358a475196be8b166230 Binary files /dev/null and b/assets/van.jpg differ diff --git a/assets/vangogh.jpg b/assets/vangogh.jpg new file mode 100644 index 0000000000000000000000000000000000000000..3f47dc248b1dab4eeceb8ab88a45a076a57c9f1f Binary files /dev/null and b/assets/vangogh.jpg differ diff --git a/assets/vangogh_hat.png b/assets/vangogh_hat.png new file mode 100644 index 0000000000000000000000000000000000000000..2b18ebcc7260a2b776ba5a96df3616ffeb8ee780 Binary files /dev/null and b/assets/vangogh_hat.png differ diff --git a/assets/wave_paint.png b/assets/wave_paint.png new file mode 100644 index 0000000000000000000000000000000000000000..d6f0f968ae992438744112956987b56c83793d3f Binary files /dev/null and b/assets/wave_paint.png differ diff --git a/assets/wave_real.jpg b/assets/wave_real.jpg new file mode 100644 index 0000000000000000000000000000000000000000..55f2e4ca8dd76a90bca144848678b71a091e23cf Binary files /dev/null and b/assets/wave_real.jpg differ diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..80bffb68ad7e3c9d2d314a8f6395eed5850d2908 --- /dev/null +++ b/main.py @@ -0,0 +1,98 @@ +import os +import torch +import numpy as np +import cv2 +from PIL import Image +from argparse import ArgumentParser +from model import DiffMorpherPipeline + +parser = ArgumentParser() +parser.add_argument( + "--model_path", type=str, default="stabilityai/stable-diffusion-2-1-base", + help="Pretrained model to use (default: %(default)s)" +) +parser.add_argument( + "--image_path_0", type=str, default="", + help="Path of the first image (default: %(default)s)") +parser.add_argument( + "--prompt_0", type=str, default="", + help="Prompt of the second image (default: %(default)s)") +parser.add_argument( + "--image_path_1", type=str, default="", + help="Path of the first image (default: %(default)s)") +parser.add_argument( + "--prompt_1", type=str, default="", + help="Prompt of the second image (default: %(default)s)") +parser.add_argument( + "--output_path", type=str, default="./results", + help="Path of the output image (default: %(default)s)" +) +parser.add_argument( + "--save_lora_dir", type=str, default="./lora", + help="Path of the output lora directory (default: %(default)s)" +) +parser.add_argument( + "--load_lora_path_0", type=str, default="", + help="Path of the lora directory of the first image (default: %(default)s)" +) +parser.add_argument( + "--load_lora_path_1", type=str, default="", + help="Path of the lora directory of the 
second image (default: %(default)s)" +) +parser.add_argument( + "--use_adain", action="store_true", + help="Use AdaIN (default: %(default)s)" +) +parser.add_argument( + "--use_reschedule", action="store_true", + help="Use reschedule sampling (default: %(default)s)" +) +parser.add_argument( + "--lamb", type=float, default=0.6, + help="Lambda for self-attention replacement (default: %(default)s)" +) +parser.add_argument( + "--fix_lora_value", type=float, default=None, + help="Fix lora value (default: LoRA Interp., not fixed)" +) +parser.add_argument( + "--save_inter", action="store_true", + help="Save intermediate results (default: %(default)s)" +) +parser.add_argument( + "--num_frames", type=int, default=16, + help="Number of frames to generate (default: %(default)s)" +) +parser.add_argument( + "--duration", type=int, default=100, + help="Duration of each frame (default: %(default)s ms)" +) +parser.add_argument( + "--no_lora", action="store_true" +) + +args = parser.parse_args() + +os.makedirs(args.output_path, exist_ok=True) +pipeline = DiffMorpherPipeline.from_pretrained( + args.model_path, torch_dtype=torch.float32) +pipeline.to("cuda") +images = pipeline( + img_path_0=args.image_path_0, + img_path_1=args.image_path_1, + prompt_0=args.prompt_0, + prompt_1=args.prompt_1, + save_lora_dir=args.save_lora_dir, + load_lora_path_0=args.load_lora_path_0, + load_lora_path_1=args.load_lora_path_1, + use_adain=args.use_adain, + use_reschedule=args.use_reschedule, + lamd=args.lamb, + output_path=args.output_path, + num_frames=args.num_frames, + fix_lora=args.fix_lora_value, + save_intermediates=args.save_inter, + use_lora=not args.no_lora +) +images[0].save(f"{args.output_path}/output.gif", save_all=True, + append_images=images[1:], duration=args.duration, loop=0) diff --git a/model.py b/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1547c3c0f227c36f37bec5e6e5a27d4ace2b26da --- /dev/null +++ b/model.py @@ -0,0 +1,639 @@ +import os +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import AttnProcessor +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +import torch +import torch.nn.functional as F +import tqdm +import numpy as np +import safetensors +from PIL import Image +from torchvision import transforms +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from diffusers import StableDiffusionPipeline +from argparse import ArgumentParser +import inspect + +from utils.model_utils import get_img, slerp, do_replace_attn +from utils.lora_utils import train_lora, load_lora +from utils.alpha_scheduler import AlphaScheduler + + +class StoreProcessor(): + def __init__(self, original_processor, value_dict, name): + self.original_processor = original_processor + self.value_dict = value_dict + self.name = name + self.value_dict[self.name] = dict() + self.id = 0 + + def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs): + # Is self attention + if encoder_hidden_states is None: + self.value_dict[self.name][self.id] = hidden_states.detach() + self.id += 1 + res = self.original_processor(attn, hidden_states, *args, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs) + + return res + + +class LoadProcessor(): + def __init__(self, original_processor, name, img0_dict, img1_dict, alpha, beta=0, lamd=0.6): + 
super().__init__() + self.original_processor = original_processor + self.name = name + self.img0_dict = img0_dict + self.img1_dict = img1_dict + self.alpha = alpha + self.beta = beta + self.lamd = lamd + self.id = 0 + + def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs): + # Is self attention + if encoder_hidden_states is None: + if self.id < 50 * self.lamd: + map0 = self.img0_dict[self.name][self.id] + map1 = self.img1_dict[self.name][self.id] + cross_map = self.beta * hidden_states + \ + (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1) + # cross_map = self.beta * hidden_states + \ + # (1 - self.beta) * slerp(map0, map1, self.alpha) + # cross_map = slerp(slerp(map0, map1, self.alpha), + # hidden_states, self.beta) + # cross_map = hidden_states + # cross_map = torch.cat( + # ((1 - self.alpha) * map0, self.alpha * map1), dim=1) + + res = self.original_processor(attn, hidden_states, *args, + encoder_hidden_states=cross_map, + attention_mask=attention_mask, + **kwargs) + else: + res = self.original_processor(attn, hidden_states, *args, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs) + + self.id += 1 + # if self.id == len(self.img0_dict[self.name]): + if self.id == len(self.img0_dict[self.name]): + self.id = 0 + else: + res = self.original_processor(attn, hidden_states, *args, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs) + + return res + + +class DiffMorpherPipeline(StableDiffusionPipeline): + + def __init__(self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder=None, + requires_safety_checker: bool = True, + ): + sig = inspect.signature(super().__init__) + params = sig.parameters + if 'image_encoder' in params: + super().__init__(vae, text_encoder, tokenizer, unet, scheduler, + safety_checker, feature_extractor, image_encoder, requires_safety_checker) + else: + super().__init__(vae, text_encoder, tokenizer, unet, scheduler, + safety_checker, feature_extractor, requires_safety_checker) + self.img0_dict = dict() + self.img1_dict = dict() + + def inv_step( + self, + model_output: torch.FloatTensor, + timestep: int, + x: torch.FloatTensor, + eta=0., + verbose=False + ): + """ + Inverse sampling for DDIM Inversion + """ + if verbose: + print("timestep: ", timestep) + next_step = timestep + timestep = min(timestep - self.scheduler.config.num_train_timesteps // + self.scheduler.num_inference_steps, 999) + alpha_prod_t = self.scheduler.alphas_cumprod[ + timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod + alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step] + beta_prod_t = 1 - alpha_prod_t + pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output + x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir + return x_next, pred_x0 + + @torch.no_grad() + def invert( + self, + image: torch.Tensor, + prompt, + num_inference_steps=50, + num_actual_inference_steps=None, + guidance_scale=1., + eta=0.0, + **kwds): + """ + invert a real image into noise map with determinisc DDIM inversion + """ + DEVICE = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + batch_size = image.shape[0] + if isinstance(prompt, list): + if batch_size == 
1: + image = image.expand(len(prompt), -1, -1, -1) + elif isinstance(prompt, str): + if batch_size > 1: + prompt = [prompt] * batch_size + + # text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0] + print("input text embeddings :", text_embeddings.shape) + # define initial latents + latents = self.image2latent(image) + + # unconditional embedding for classifier free guidance + if guidance_scale > 1.: + max_length = text_input.input_ids.shape[-1] + unconditional_input = self.tokenizer( + [""] * batch_size, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + unconditional_embeddings = self.text_encoder( + unconditional_input.input_ids.to(DEVICE))[0] + text_embeddings = torch.cat( + [unconditional_embeddings, text_embeddings], dim=0) + + print("latents shape: ", latents.shape) + # interative sampling + self.scheduler.set_timesteps(num_inference_steps) + print("Valid timesteps: ", reversed(self.scheduler.timesteps)) + # print("attributes: ", self.scheduler.__dict__) + latents_list = [latents] + pred_x0_list = [latents] + for i, t in enumerate(tqdm.tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")): + if num_actual_inference_steps is not None and i >= num_actual_inference_steps: + continue + + if guidance_scale > 1.: + model_inputs = torch.cat([latents] * 2) + else: + model_inputs = latents + + # predict the noise + noise_pred = self.unet( + model_inputs, t, encoder_hidden_states=text_embeddings).sample + if guidance_scale > 1.: + noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncon + guidance_scale * \ + (noise_pred_con - noise_pred_uncon) + # compute the previous noise sample x_t-1 -> x_t + latents, pred_x0 = self.inv_step(noise_pred, t, latents) + latents_list.append(latents) + pred_x0_list.append(pred_x0) + + return latents + + @torch.no_grad() + def ddim_inversion(self, latent, cond): + timesteps = reversed(self.scheduler.timesteps) + with torch.autocast(device_type='cuda', dtype=torch.float32): + for i, t in enumerate(tqdm.tqdm(timesteps, desc="DDIM inversion")): + cond_batch = cond.repeat(latent.shape[0], 1, 1) + + alpha_prod_t = self.scheduler.alphas_cumprod[t] + alpha_prod_t_prev = ( + self.scheduler.alphas_cumprod[timesteps[i - 1]] + if i > 0 else self.scheduler.final_alpha_cumprod + ) + + mu = alpha_prod_t ** 0.5 + mu_prev = alpha_prod_t_prev ** 0.5 + sigma = (1 - alpha_prod_t) ** 0.5 + sigma_prev = (1 - alpha_prod_t_prev) ** 0.5 + + eps = self.unet( + latent, t, encoder_hidden_states=cond_batch).sample + + pred_x0 = (latent - sigma_prev * eps) / mu_prev + latent = mu * pred_x0 + sigma * eps + # if save_latents: + # torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt')) + # torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt')) + return latent + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + x: torch.FloatTensor, + ): + """ + predict the sample of the next step in the denoise process. 
+ """ + prev_timestep = timestep - \ + self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = self.scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output + x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir + return x_prev, pred_x0 + + @torch.no_grad() + def image2latent(self, image): + DEVICE = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + if type(image) is Image: + image = np.array(image) + image = torch.from_numpy(image).float() / 127.5 - 1 + image = image.permute(2, 0, 1).unsqueeze(0) + # input image density range [-1, 1] + latents = self.vae.encode(image.to(DEVICE))['latent_dist'].mean + latents = latents * 0.18215 + return latents + + @torch.no_grad() + def latent2image(self, latents, return_type='np'): + latents = 1 / 0.18215 * latents.detach() + image = self.vae.decode(latents)['sample'] + if return_type == 'np': + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy()[0] + image = (image * 255).astype(np.uint8) + elif return_type == "pt": + image = (image / 2 + 0.5).clamp(0, 1) + + return image + + def latent2image_grad(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents)['sample'] + + return image # range [-1, 1] + + @torch.no_grad() + def cal_latent(self, num_inference_steps, guidance_scale, unconditioning, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1, alpha, use_lora, fix_lora=None): + # latents = torch.cos(alpha * torch.pi / 2) * img_noise_0 + \ + # torch.sin(alpha * torch.pi / 2) * img_noise_1 + # latents = (1 - alpha) * img_noise_0 + alpha * img_noise_1 + # latents = latents / ((1 - alpha) ** 2 + alpha ** 2) + latents = slerp(img_noise_0, img_noise_1, alpha, self.use_adain) + text_embeddings = (1 - alpha) * text_embeddings_0 + \ + alpha * text_embeddings_1 + + self.scheduler.set_timesteps(num_inference_steps) + if use_lora: + if fix_lora is not None: + self.unet = load_lora(self.unet, lora_0, lora_1, fix_lora) + else: + self.unet = load_lora(self.unet, lora_0, lora_1, alpha) + + for i, t in enumerate(tqdm.tqdm(self.scheduler.timesteps, desc=f"DDIM Sampler, alpha={alpha}")): + + if guidance_scale > 1.: + model_inputs = torch.cat([latents] * 2) + else: + model_inputs = latents + if unconditioning is not None and isinstance(unconditioning, list): + _, text_embeddings = text_embeddings.chunk(2) + text_embeddings = torch.cat( + [unconditioning[i].expand(*text_embeddings.shape), text_embeddings]) + # predict the noise + noise_pred = self.unet( + model_inputs, t, encoder_hidden_states=text_embeddings).sample + if guidance_scale > 1.0: + noise_pred_uncon, noise_pred_con = noise_pred.chunk( + 2, dim=0) + noise_pred = noise_pred_uncon + guidance_scale * \ + (noise_pred_con - noise_pred_uncon) + # compute the previous noise sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False)[0] + return latents + + @torch.no_grad() + def get_text_embeddings(self, prompt, guidance_scale, neg_prompt, batch_size): + DEVICE = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + # text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + 
return_tensors="pt" + ) + text_embeddings = self.text_encoder(text_input.input_ids.cuda())[0] + + if guidance_scale > 1.: + if neg_prompt: + uc_text = neg_prompt + else: + uc_text = "" + unconditional_input = self.tokenizer( + [uc_text] * batch_size, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + unconditional_embeddings = self.text_encoder( + unconditional_input.input_ids.to(DEVICE))[0] + text_embeddings = torch.cat( + [unconditional_embeddings, text_embeddings], dim=0) + + return text_embeddings + + def __call__( + self, + img_0=None, + img_1=None, + img_path_0=None, + img_path_1=None, + prompt_0="", + prompt_1="", + save_lora_dir="./lora", + load_lora_path_0=None, + load_lora_path_1=None, + lora_steps=200, + lora_lr=2e-4, + lora_rank=16, + batch_size=1, + height=512, + width=512, + num_inference_steps=50, + num_actual_inference_steps=None, + guidance_scale=1, + attn_beta=0, + lamd=0.6, + use_lora=True, + use_adain=True, + use_reschedule=True, + output_path="./results", + num_frames=50, + fix_lora=None, + progress=tqdm, + unconditioning=None, + neg_prompt=None, + save_intermediates=False, + **kwds): + + # if isinstance(prompt, list): + # batch_size = len(prompt) + # elif isinstance(prompt, str): + # if batch_size > 1: + # prompt = [prompt] * batch_size + self.scheduler.set_timesteps(num_inference_steps) + self.use_lora = use_lora + self.use_adain = use_adain + self.use_reschedule = use_reschedule + self.output_path = output_path + + if img_0 is None: + img_0 = Image.open(img_path_0).convert("RGB") + # else: + # img_0 = Image.fromarray(img_0).convert("RGB") + + if img_1 is None: + img_1 = Image.open(img_path_1).convert("RGB") + # else: + # img_1 = Image.fromarray(img_1).convert("RGB") + + if self.use_lora: + print("Loading lora...") + if not load_lora_path_0: + + weight_name = f"{output_path.split('/')[-1]}_lora_0.ckpt" + load_lora_path_0 = save_lora_dir + "/" + weight_name + if not os.path.exists(load_lora_path_0): + train_lora(img_0, prompt_0, save_lora_dir, None, self.tokenizer, self.text_encoder, + self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name) + print(f"Load from {load_lora_path_0}.") + if load_lora_path_0.endswith(".safetensors"): + lora_0 = safetensors.torch.load_file( + load_lora_path_0, device="cpu") + else: + lora_0 = torch.load(load_lora_path_0, map_location="cpu") + + if not load_lora_path_1: + weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt" + load_lora_path_1 = save_lora_dir + "/" + weight_name + if not os.path.exists(load_lora_path_1): + train_lora(img_1, prompt_1, save_lora_dir, None, self.tokenizer, self.text_encoder, + self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name) + print(f"Load from {load_lora_path_1}.") + if load_lora_path_1.endswith(".safetensors"): + lora_1 = safetensors.torch.load_file( + load_lora_path_1, device="cpu") + else: + lora_1 = torch.load(load_lora_path_1, map_location="cpu") + else: + lora_0 = lora_1 = None + + text_embeddings_0 = self.get_text_embeddings( + prompt_0, guidance_scale, neg_prompt, batch_size) + text_embeddings_1 = self.get_text_embeddings( + prompt_1, guidance_scale, neg_prompt, batch_size) + img_0 = get_img(img_0) + img_1 = get_img(img_1) + if self.use_lora: + self.unet = load_lora(self.unet, lora_0, lora_1, 0) + img_noise_0 = self.ddim_inversion( + self.image2latent(img_0), text_embeddings_0) + if self.use_lora: + self.unet = load_lora(self.unet, lora_0, lora_1, 1) + img_noise_1 = self.ddim_inversion( + 
self.image2latent(img_1), text_embeddings_1) + + print("latents shape: ", img_noise_0.shape) + + original_processor = list(self.unet.attn_processors.values())[0] + + def morph(alpha_list, progress, desc): + images = [] + if attn_beta is not None: + if self.use_lora: + self.unet = load_lora( + self.unet, lora_0, lora_1, 0 if fix_lora is None else fix_lora) + + attn_processor_dict = {} + for k in self.unet.attn_processors.keys(): + if do_replace_attn(k): + if self.use_lora: + attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k], + self.img0_dict, k) + else: + attn_processor_dict[k] = StoreProcessor(original_processor, + self.img0_dict, k) + else: + attn_processor_dict[k] = self.unet.attn_processors[k] + self.unet.set_attn_processor(attn_processor_dict) + + latents = self.cal_latent( + num_inference_steps, + guidance_scale, + unconditioning, + img_noise_0, + img_noise_1, + text_embeddings_0, + text_embeddings_1, + lora_0, + lora_1, + alpha_list[0], + False, + fix_lora + ) + first_image = self.latent2image(latents) + first_image = Image.fromarray(first_image) + if save_intermediates: + first_image.save(f"{self.output_path}/{0:02d}.png") + + if self.use_lora: + self.unet = load_lora( + self.unet, lora_0, lora_1, 1 if fix_lora is None else fix_lora) + attn_processor_dict = {} + for k in self.unet.attn_processors.keys(): + if do_replace_attn(k): + if self.use_lora: + attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k], + self.img1_dict, k) + else: + attn_processor_dict[k] = StoreProcessor(original_processor, + self.img1_dict, k) + else: + attn_processor_dict[k] = self.unet.attn_processors[k] + + self.unet.set_attn_processor(attn_processor_dict) + + latents = self.cal_latent( + num_inference_steps, + guidance_scale, + unconditioning, + img_noise_0, + img_noise_1, + text_embeddings_0, + text_embeddings_1, + lora_0, + lora_1, + alpha_list[-1], + False, + fix_lora + ) + last_image = self.latent2image(latents) + last_image = Image.fromarray(last_image) + if save_intermediates: + last_image.save( + f"{self.output_path}/{num_frames - 1:02d}.png") + + for i in progress.tqdm(range(1, num_frames - 1), desc=desc): + alpha = alpha_list[i] + if self.use_lora: + self.unet = load_lora( + self.unet, lora_0, lora_1, alpha if fix_lora is None else fix_lora) + + attn_processor_dict = {} + for k in self.unet.attn_processors.keys(): + if do_replace_attn(k): + if self.use_lora: + attn_processor_dict[k] = LoadProcessor( + self.unet.attn_processors[k], k, self.img0_dict, self.img1_dict, alpha, attn_beta, lamd) + else: + attn_processor_dict[k] = LoadProcessor( + original_processor, k, self.img0_dict, self.img1_dict, alpha, attn_beta, lamd) + else: + attn_processor_dict[k] = self.unet.attn_processors[k] + + self.unet.set_attn_processor(attn_processor_dict) + + latents = self.cal_latent( + num_inference_steps, + guidance_scale, + unconditioning, + img_noise_0, + img_noise_1, + text_embeddings_0, + text_embeddings_1, + lora_0, + lora_1, + alpha_list[i], + False, + fix_lora + ) + image = self.latent2image(latents) + image = Image.fromarray(image) + if save_intermediates: + image.save(f"{self.output_path}/{i:02d}.png") + images.append(image) + + images = [first_image] + images + [last_image] + + else: + for k, alpha in enumerate(alpha_list): + + latents = self.cal_latent( + num_inference_steps, + guidance_scale, + unconditioning, + img_noise_0, + img_noise_1, + text_embeddings_0, + text_embeddings_1, + lora_0, + lora_1, + alpha_list[k], + self.use_lora, + fix_lora + ) + image = 
self.latent2image(latents) + image = Image.fromarray(image) + if save_intermediates: + image.save(f"{self.output_path}/{k:02d}.png") + images.append(image) + + return images + + with torch.no_grad(): + if self.use_reschedule: + alpha_scheduler = AlphaScheduler() + alpha_list = list(torch.linspace(0, 1, num_frames)) + images_pt = morph(alpha_list, progress, "Sampling...") + images_pt = [transforms.ToTensor()(img).unsqueeze(0) + for img in images_pt] + alpha_scheduler.from_imgs(images_pt) + alpha_list = alpha_scheduler.get_list() + print(alpha_list) + images = morph(alpha_list, progress, "Reschedule..." + ) + else: + alpha_list = list(torch.linspace(0, 1, num_frames)) + print(alpha_list) + images = morph(alpha_list, progress, "Sampling...") + + return images diff --git a/multi_image/README.md b/multi_image/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c477e1151098448864219c304046c197d65761f2 --- /dev/null +++ b/multi_image/README.md @@ -0,0 +1,37 @@ +## Update + +Add support for multi-image input. Now you can get the morphing output among more than 2 images. + +## Run the code + +You can run the code with the following command: + +``` +python main.py \ + --image_paths [image_path_0] ... [image_path_n] \ + --prompts [prompt_0] ... [prompt_n] \ + --output_path [output_path] \ + --use_adain --use_reschedule --save_inter +``` + +This modification add support for the following options: + +- `--image_paths`: Paths of the input images +- `--prompts`: Prompts of the images +- `--load_lora_paths`: Paths of the lora directory of the images + +## Example + +Run the code: +``` +python main.py \ +--image_paths ./assets/realdog0.jpg ./assets/realdog1.jpg ./assets/realdog2.jpg \ +--prompts "A photo of a dog" "A photo of a dog" "A photo of a dog" \ +--output_path "./results/dog" \ +--use_adain --use_reschedule --save_inter +``` + +Output: +
+ +
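If LoRAs for the input images have already been trained, the `--load_lora_paths` option listed above can reuse them instead of retraining. The sketch below is only an assumed usage: the checkpoint paths are placeholders and should point to wherever your LoRA weights were actually saved (e.g. `.ckpt` or `.safetensors` files).

```
python main.py \
    --image_paths ./assets/realdog0.jpg ./assets/realdog1.jpg ./assets/realdog2.jpg \
    --prompts "A photo of a dog" "A photo of a dog" "A photo of a dog" \
    --load_lora_paths ./lora/realdog0.ckpt ./lora/realdog1.ckpt ./lora/realdog2.ckpt \
    --output_path "./results/dog" \
    --use_adain --use_reschedule --save_inter
```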
\ No newline at end of file diff --git a/multi_image/assets/realdog.gif b/multi_image/assets/realdog.gif new file mode 100644 index 0000000000000000000000000000000000000000..2ca770cc1757cfdd2ca6447a7e6f20c7ffc73715 --- /dev/null +++ b/multi_image/assets/realdog.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a70a7919940a458f77ee8715f09d4ff5a944a19ea70d643e17ab2f555034ed6a +size 5214125 diff --git a/multi_image/assets/realdog0.jpg b/multi_image/assets/realdog0.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a74cf29ab8b624e00f6db0d38570d096a3e89908 Binary files /dev/null and b/multi_image/assets/realdog0.jpg differ diff --git a/multi_image/assets/realdog1.jpg b/multi_image/assets/realdog1.jpg new file mode 100644 index 0000000000000000000000000000000000000000..12c94866d10cc8172a64780ace97ce79eb8ef0cd Binary files /dev/null and b/multi_image/assets/realdog1.jpg differ diff --git a/multi_image/assets/realdog2.jpg b/multi_image/assets/realdog2.jpg new file mode 100644 index 0000000000000000000000000000000000000000..d48a2372875231e805b512b3a435083d19ecd77a Binary files /dev/null and b/multi_image/assets/realdog2.jpg differ diff --git a/multi_image/main.py b/multi_image/main.py new file mode 100644 index 0000000000000000000000000000000000000000..ebbbafa8f88aed113996182a568c3feb7673ea52 --- /dev/null +++ b/multi_image/main.py @@ -0,0 +1,107 @@ +import os +import torch +import numpy as np +import cv2 +from PIL import Image +from argparse import ArgumentParser +from model import DiffMorpherPipeline + +parser = ArgumentParser() +parser.add_argument( + "--model_path", type=str, default="stabilityai/stable-diffusion-2-1-base", + help="Pretrained model to use (default: %(default)s)" +) +parser.add_argument( + "--image_path_0", type=str, default="", + help="Path of the first image (default: %(default)s)") +parser.add_argument( + "--prompt_0", type=str, default="", + help="Prompt of the second image (default: %(default)s)") +parser.add_argument( + "--image_path_1", type=str, default="", + help="Path of the first image (default: %(default)s)") +parser.add_argument( + "--prompt_1", type=str, default="", + help="Prompt of the second image (default: %(default)s)") +parser.add_argument( + "--load_lora_path_0", type=str, default="", + help="Path of the lora directory of the first image (default: %(default)s)" +) +parser.add_argument( + "--load_lora_path_1", type=str, default="", + help="Path of the lora directory of the second image (default: %(default)s)" +) +parser.add_argument( + "--image_paths", type=str, nargs='*', default=[], + help="Path of the first image (default: %(default)s)") +parser.add_argument( + "--prompts", type=str, nargs='*', default=[], + help="Prompt of the second image (default: %(default)s)") +parser.add_argument( + "--output_path", type=str, default="./results", + help="Path of the output image (default: %(default)s)" +) +parser.add_argument( + "--save_lora_dir", type=str, default="./lora", + help="Path of the output lora directory (default: %(default)s)" +) +parser.add_argument( + "--load_lora_paths", type=str, nargs='*', default=[], + help="Path of the lora directory of the first image (default: %(default)s)" +) +parser.add_argument( + "--use_adain", action="store_true", + help="Use AdaIN (default: %(default)s)" +) +parser.add_argument( + "--use_reschedule", action="store_true", + help="Use reschedule sampling (default: %(default)s)" +) +parser.add_argument( + "--lamb", type=float, default=0.6, + help="Lambda for self-attention 
replacement (default: %(default)s)" +) +parser.add_argument( + "--fix_lora_value", type=float, default=None, + help="Fix lora value (default: LoRA Interp., not fixed)" +) +parser.add_argument( + "--save_inter", action="store_true", + help="Save intermediate results (default: %(default)s)" +) +parser.add_argument( + "--num_frames", type=int, default=16, + help="Number of frames to generate (default: %(default)s)" +) +parser.add_argument( + "--duration", type=int, default=100, + help="Duration of each frame (default: %(default)s ms)" +) + +args = parser.parse_args() + +os.makedirs(args.output_path, exist_ok=True) +pipeline = DiffMorpherPipeline.from_pretrained( + args.model_path, torch_dtype=torch.float32) +pipeline.to("cuda") +images = pipeline( + img_path_0=args.image_path_0, + img_path_1=args.image_path_1, + prompt_0=args.prompt_0, + prompt_1=args.prompt_1, + load_lora_path_0=args.load_lora_path_0, + load_lora_path_1=args.load_lora_path_1, + img_paths=args.image_paths, + prompts=args.prompts, + save_lora_dir=args.save_lora_dir, + load_lora_paths=args.load_lora_paths, + use_adain=args.use_adain, + use_reschedule=args.use_reschedule, + lamb=args.lamb, + output_path=args.output_path, + num_frames=args.num_frames, + fix_lora=args.fix_lora_value, + save_intermediates=args.save_inter, +) +images[0].save(f"{args.output_path}/output.gif", save_all=True, + append_images=images[1:], duration=args.duration, loop=0) diff --git a/multi_image/model.py b/multi_image/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6ecb346dc40cd0a26e818a2db4c4eb0e837a7e20 --- /dev/null +++ b/multi_image/model.py @@ -0,0 +1,699 @@ +import os +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import AttnProcessor +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +import torch +import torch.nn.functional as F +import tqdm +import numpy as np +import safetensors +from PIL import Image +from torchvision import transforms +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer +from diffusers import StableDiffusionPipeline +from argparse import ArgumentParser + + +from utils.model_utils import get_img, slerp, do_replace_attn +from utils.lora_utils import train_lora, load_lora +from utils.alpha_scheduler import AlphaScheduler + +class StoreProcessor(): + def __init__(self, original_processor, value_dict, name): + self.original_processor = original_processor + self.value_dict = value_dict + self.name = name + self.value_dict[self.name] = dict() + self.id = 0 + + def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs): + # Is self attention + if encoder_hidden_states is None: + self.value_dict[self.name][self.id] = hidden_states.detach() + self.id += 1 + res = self.original_processor(attn, hidden_states, *args, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs) + + return res + + +class LoadProcessor(): + def __init__(self, original_processor, name, img0_dict, img1_dict, alpha, beta=0, lamd=0.6): + super().__init__() + self.original_processor = original_processor + self.name = name + self.img0_dict = img0_dict + self.img1_dict = img1_dict + self.alpha = alpha + self.beta = beta + self.lamd = lamd + self.id = 0 + + def parent_call( + self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 + ): + residual = 
hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size) + + if attn.group_norm is not None: + hidden_states = attn.group_norm( + hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + scale * \ + self.original_processor.to_q_lora(hidden_states) + query = attn.head_to_batch_dim(query) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + scale * \ + self.original_processor.to_k_lora(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + scale * \ + self.original_processor.to_v_lora(encoder_hidden_states) + + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + attention_probs = attn.get_attention_scores( + query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0]( + hidden_states) + scale * self.original_processor.to_out_lora(hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose( + -1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs): + # Is self attention + if encoder_hidden_states is None: + # hardcode timestep + if self.id < 50 * self.lamd: + map0 = self.img0_dict[self.name][self.id] + map1 = self.img1_dict[self.name][self.id] + cross_map = self.beta * hidden_states + \ + (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1) + # cross_map = self.beta * hidden_states + \ + # (1 - self.beta) * slerp(map0, map1, self.alpha) + # cross_map = slerp(slerp(map0, map1, self.alpha), + # hidden_states, self.beta) + # cross_map = hidden_states + # cross_map = torch.cat( + # ((1 - self.alpha) * map0, self.alpha * map1), dim=1) + + # res = self.original_processor(attn, hidden_states, *args, + # encoder_hidden_states=cross_map, + # attention_mask=attention_mask, + # temb=temb, **kwargs) + res = self.parent_call(attn, hidden_states, *args, + encoder_hidden_states=cross_map, + attention_mask=attention_mask, + **kwargs) + else: + res = self.original_processor(attn, hidden_states, *args, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs) + + self.id += 1 + # if self.id == len(self.img0_dict[self.name]): + if self.id == len(self.img0_dict[self.name]): + self.id = 0 + else: + res = self.original_processor(attn, hidden_states, *args, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + **kwargs) + + return res + + +class DiffMorpherPipeline(StableDiffusionPipeline): + + def __init__(self, + vae: AutoencoderKL, + text_encoder: 
CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + ): + super().__init__(vae, text_encoder, tokenizer, unet, scheduler, + safety_checker, feature_extractor, requires_safety_checker) + self.img0_dict = dict() + self.img1_dict = dict() + + def inv_step( + self, + model_output: torch.FloatTensor, + timestep: int, + x: torch.FloatTensor, + eta=0., + verbose=False + ): + """ + Inverse sampling for DDIM Inversion + """ + if verbose: + print("timestep: ", timestep) + next_step = timestep + timestep = min(timestep - self.scheduler.config.num_train_timesteps // + self.scheduler.num_inference_steps, 999) + alpha_prod_t = self.scheduler.alphas_cumprod[ + timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod + alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step] + beta_prod_t = 1 - alpha_prod_t + pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output + x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir + return x_next, pred_x0 + + @torch.no_grad() + def invert( + self, + image: torch.Tensor, + prompt, + num_inference_steps=50, + num_actual_inference_steps=None, + guidance_scale=1., + eta=0.0, + **kwds): + """ + invert a real image into noise map with determinisc DDIM inversion + """ + DEVICE = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + batch_size = image.shape[0] + if isinstance(prompt, list): + if batch_size == 1: + image = image.expand(len(prompt), -1, -1, -1) + elif isinstance(prompt, str): + if batch_size > 1: + prompt = [prompt] * batch_size + + # text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0] + print("input text embeddings :", text_embeddings.shape) + # define initial latents + latents = self.image2latent(image) + + # unconditional embedding for classifier free guidance + if guidance_scale > 1.: + max_length = text_input.input_ids.shape[-1] + unconditional_input = self.tokenizer( + [""] * batch_size, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + unconditional_embeddings = self.text_encoder( + unconditional_input.input_ids.to(DEVICE))[0] + text_embeddings = torch.cat( + [unconditional_embeddings, text_embeddings], dim=0) + + print("latents shape: ", latents.shape) + # interative sampling + self.scheduler.set_timesteps(num_inference_steps) + print("Valid timesteps: ", reversed(self.scheduler.timesteps)) + # print("attributes: ", self.scheduler.__dict__) + latents_list = [latents] + pred_x0_list = [latents] + for i, t in enumerate(tqdm.tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")): + if num_actual_inference_steps is not None and i >= num_actual_inference_steps: + continue + + if guidance_scale > 1.: + model_inputs = torch.cat([latents] * 2) + else: + model_inputs = latents + + # predict the noise + noise_pred = self.unet( + model_inputs, t, encoder_hidden_states=text_embeddings).sample + if guidance_scale > 1.: + noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncon + guidance_scale * \ + (noise_pred_con - noise_pred_uncon) + # compute the previous noise sample x_t-1 -> x_t + latents, pred_x0 = self.inv_step(noise_pred, t, latents) + 
latents_list.append(latents) + pred_x0_list.append(pred_x0) + + return latents + + @torch.no_grad() + def ddim_inversion(self, latent, cond): + timesteps = reversed(self.scheduler.timesteps) + with torch.autocast(device_type='cuda', dtype=torch.float32): + for i, t in enumerate(tqdm.tqdm(timesteps, desc="DDIM inversion")): + cond_batch = cond.repeat(latent.shape[0], 1, 1) + + alpha_prod_t = self.scheduler.alphas_cumprod[t] + alpha_prod_t_prev = ( + self.scheduler.alphas_cumprod[timesteps[i - 1]] + if i > 0 else self.scheduler.final_alpha_cumprod + ) + + mu = alpha_prod_t ** 0.5 + mu_prev = alpha_prod_t_prev ** 0.5 + sigma = (1 - alpha_prod_t) ** 0.5 + sigma_prev = (1 - alpha_prod_t_prev) ** 0.5 + + eps = self.unet( + latent, t, encoder_hidden_states=cond_batch).sample + + pred_x0 = (latent - sigma_prev * eps) / mu_prev + latent = mu * pred_x0 + sigma * eps + # if save_latents: + # torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt')) + # torch.save(latent, os.path.join(save_path, f'noisy_latents_{t}.pt')) + return latent + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + x: torch.FloatTensor, + ): + """ + predict the sample of the next step in the denoise process. + """ + prev_timestep = timestep - \ + self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps + alpha_prod_t = self.scheduler.alphas_cumprod[timestep] + alpha_prod_t_prev = self.scheduler.alphas_cumprod[ + prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5 + pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output + x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir + return x_prev, pred_x0 + + @torch.no_grad() + def image2latent(self, image): + DEVICE = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + if type(image) is Image: + image = np.array(image) + image = torch.from_numpy(image).float() / 127.5 - 1 + image = image.permute(2, 0, 1).unsqueeze(0) + # input image density range [-1, 1] + latents = self.vae.encode(image.to(DEVICE))['latent_dist'].mean + latents = latents * 0.18215 + return latents + + @torch.no_grad() + def latent2image(self, latents, return_type='np'): + latents = 1 / 0.18215 * latents.detach() + image = self.vae.decode(latents)['sample'] + if return_type == 'np': + image = (image / 2 + 0.5).clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).numpy()[0] + image = (image * 255).astype(np.uint8) + elif return_type == "pt": + image = (image / 2 + 0.5).clamp(0, 1) + + return image + + def latent2image_grad(self, latents): + latents = 1 / 0.18215 * latents + image = self.vae.decode(latents)['sample'] + + return image # range [-1, 1] + + @torch.no_grad() + def cal_latent(self, num_inference_steps, guidance_scale, unconditioning, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1, alpha, use_lora, fix_lora=None): + # latents = torch.cos(alpha * torch.pi / 2) * img_noise_0 + \ + # torch.sin(alpha * torch.pi / 2) * img_noise_1 + # latents = (1 - alpha) * img_noise_0 + alpha * img_noise_1 + # latents = latents / ((1 - alpha) ** 2 + alpha ** 2) + latents = slerp(img_noise_0, img_noise_1, alpha, self.use_adain) + text_embeddings = (1 - alpha) * text_embeddings_0 + \ + alpha * text_embeddings_1 + + self.scheduler.set_timesteps(num_inference_steps) + if use_lora: + if fix_lora is not None: + self.unet = load_lora(self.unet, lora_0, lora_1, fix_lora) + else: + self.unet = 
load_lora(self.unet, lora_0, lora_1, alpha) + + for i, t in enumerate(tqdm.tqdm(self.scheduler.timesteps, desc=f"DDIM Sampler, alpha={alpha}")): + + if guidance_scale > 1.: + model_inputs = torch.cat([latents] * 2) + else: + model_inputs = latents + if unconditioning is not None and isinstance(unconditioning, list): + _, text_embeddings = text_embeddings.chunk(2) + text_embeddings = torch.cat( + [unconditioning[i].expand(*text_embeddings.shape), text_embeddings]) + # predict the noise + noise_pred = self.unet( + model_inputs, t, encoder_hidden_states=text_embeddings).sample + if guidance_scale > 1.0: + noise_pred_uncon, noise_pred_con = noise_pred.chunk( + 2, dim=0) + noise_pred = noise_pred_uncon + guidance_scale * \ + (noise_pred_con - noise_pred_uncon) + # compute the previous noise sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, return_dict=False)[0] + return latents + + @torch.no_grad() + def get_text_embeddings(self, prompt, guidance_scale, neg_prompt, batch_size): + DEVICE = torch.device( + "cuda") if torch.cuda.is_available() else torch.device("cpu") + # text embeddings + text_input = self.tokenizer( + prompt, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + text_embeddings = self.text_encoder(text_input.input_ids.cuda())[0] + + if guidance_scale > 1.: + if neg_prompt: + uc_text = neg_prompt + else: + uc_text = "" + unconditional_input = self.tokenizer( + [uc_text] * batch_size, + padding="max_length", + max_length=77, + return_tensors="pt" + ) + unconditional_embeddings = self.text_encoder( + unconditional_input.input_ids.to(DEVICE))[0] + text_embeddings = torch.cat( + [unconditional_embeddings, text_embeddings], dim=0) + + return text_embeddings + + def __call__( + self, + img_0=None, + img_1=None, + img_path_0=None, + img_path_1=None, + prompt_0="", + prompt_1="", + imgs=[], + img_paths=None, + prompts=None, + save_lora_dir="./lora", + load_lora_path_0=None, + load_lora_path_1=None, + load_lora_paths=None, + lora_steps=200, + lora_lr=2e-4, + lora_rank=16, + batch_size=1, + height=512, + width=512, + num_inference_steps=50, + num_actual_inference_steps=None, + guidance_scale=1, + attn_beta=0, + lamd=0.6, + use_lora=True, + use_adain=True, + use_reschedule=True, + output_path = "./results", + num_frames=50, + fix_lora=None, + progress=tqdm, + unconditioning=None, + neg_prompt=None, + save_intermediates=False, + **kwds): + + + self.scheduler.set_timesteps(num_inference_steps) + self.use_lora = use_lora + self.use_adain = use_adain + self.use_reschedule = use_reschedule + self.output_path = output_path + + + imgs = [Image.open(img_path).convert("RGB") for img_path in img_paths] + assert len(prompts) == len(imgs) + + # if img_path_0 or img_0: + # img_paths = [img_path_0, img_path_1] + # prompts = [prompt_0, prompt_1] + # load_lora_paths = [load_lora_path_0, load_lora_path_1] + + # if img_0: + # imgs.append(Image.fromarray(img_0).convert("RGB")) + # if img_1: + # imgs.append(Image.fromarray(img_1).convert("RGB")) + # if imgs is None: + # imgs = [Image.open(img_path).convert("RGB") for img_path in img_paths] + # if len(prompts) < len(imgs): + # prompts += ["" for _ in range(len(imgs) - len(prompts))] + + if self.use_lora: + loras = [] + print("Loading lora...") + for i, (img, prompt) in enumerate(zip(imgs, prompts)): + if len(load_lora_paths) == i: + + weight_name = f"{output_path.split('/')[-1]}_lora_{i}.ckpt" + load_lora_paths.append(save_lora_dir + "/" + weight_name) + if not os.path.exists(load_lora_paths[i]): + train_lora(img, 
prompt, save_lora_dir, None, self.tokenizer, self.text_encoder, + self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name) + print(f"Load from {load_lora_paths[i]}.") + if load_lora_paths[i].endswith(".safetensors"): + loras.append(safetensors.torch.load_file( + load_lora_paths[i], device="cpu")) + else: + loras.append(torch.load(load_lora_paths[i], map_location="cpu")) + + # if not load_lora_path_1: + # weight_name = f"{output_path.split('/')[-1]}_lora_1.ckpt" + # load_lora_path_1 = save_lora_dir + "/" + weight_name + # if not os.path.exists(load_lora_path_1): + # train_lora(img_1, prompt_1, save_lora_dir, None, self.tokenizer, self.text_encoder, + # self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name) + # print(f"Load from {load_lora_path_1}.") + # if load_lora_path_1.endswith(".safetensors"): + # lora_1 = safetensors.torch.load_file( + # load_lora_path_1, device="cpu") + # else: + # lora_1 = torch.load(load_lora_path_1, map_location="cpu") + + def morph(alpha_list, progress, desc, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1): + images = [] + if attn_beta is not None: + + self.unet = load_lora(self.unet, lora_0, lora_1, 0 if fix_lora is None else fix_lora) + attn_processor_dict = {} + for k in self.unet.attn_processors.keys(): + if do_replace_attn(k): + attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k], + self.img0_dict, k) + else: + attn_processor_dict[k] = self.unet.attn_processors[k] + self.unet.set_attn_processor(attn_processor_dict) + + latents = self.cal_latent( + num_inference_steps, + guidance_scale, + unconditioning, + img_noise_0, + img_noise_1, + text_embeddings_0, + text_embeddings_1, + lora_0, + lora_1, + alpha_list[0], + False, + fix_lora + ) + first_image = self.latent2image(latents) + first_image = Image.fromarray(first_image) + # if save_intermediates: + # first_image.save(f"{self.output_path}/{0:02d}.png") + + self.unet = load_lora(self.unet, lora_0, lora_1, 1 if fix_lora is None else fix_lora) + attn_processor_dict = {} + for k in self.unet.attn_processors.keys(): + if do_replace_attn(k): + attn_processor_dict[k] = StoreProcessor(self.unet.attn_processors[k], + self.img1_dict, k) + else: + attn_processor_dict[k] = self.unet.attn_processors[k] + + self.unet.set_attn_processor(attn_processor_dict) + + latents = self.cal_latent( + num_inference_steps, + guidance_scale, + unconditioning, + img_noise_0, + img_noise_1, + text_embeddings_0, + text_embeddings_1, + lora_0, + lora_1, + alpha_list[-1], + False, + fix_lora + ) + last_image = self.latent2image(latents) + last_image = Image.fromarray(last_image) + # if save_intermediates: + # last_image.save( + # f"{self.output_path}/{num_frames - 1:02d}.png") + + for i in progress.tqdm(range(1, num_frames - 1), desc=desc): + alpha = alpha_list[i] + self.unet = load_lora(self.unet, lora_0, lora_1, alpha if fix_lora is None else fix_lora) + attn_processor_dict = {} + for k in self.unet.attn_processors.keys(): + if do_replace_attn(k): + attn_processor_dict[k] = LoadProcessor( + self.unet.attn_processors[k], k, self.img0_dict, self.img1_dict, alpha, attn_beta, lamd) + else: + attn_processor_dict[k] = self.unet.attn_processors[k] + + self.unet.set_attn_processor(attn_processor_dict) + + latents = self.cal_latent( + num_inference_steps, + guidance_scale, + unconditioning, + img_noise_0, + img_noise_1, + text_embeddings_0, + text_embeddings_1, + lora_0, + lora_1, + alpha_list[i], + False, + fix_lora + ) + 
image = self.latent2image(latents) + image = Image.fromarray(image) + # if save_intermediates: + # image.save(f"{self.output_path}/{i:02d}.png") + images.append(image) + + images = [first_image] + images + [last_image] + + else: + for k, alpha in enumerate(alpha_list): + + latents = self.cal_latent( + num_inference_steps, + guidance_scale, + unconditioning, + img_noise_0, + img_noise_1, + text_embeddings_0, + text_embeddings_1, + lora_0, + lora_1, + alpha_list[k], + self.use_lora, + fix_lora + ) + image = self.latent2image(latents) + image = Image.fromarray(image) + # if save_intermediates: + # image.save(f"{self.output_path}/{k:02d}.png") + images.append(image) + + return images + + images = [] + + for img_0, img_1, prompt_0, prompt_1, lora_0, lora_1 in zip(imgs[:-1], imgs[1:], prompts[:-1], prompts[1:], loras[:-1], loras[1:]): + text_embeddings_0 = self.get_text_embeddings( + prompt_0, guidance_scale, neg_prompt, batch_size) + text_embeddings_1 = self.get_text_embeddings( + prompt_1, guidance_scale, neg_prompt, batch_size) + img_0 = get_img(img_0) + img_1 = get_img(img_1) + if self.use_lora: + self.unet = load_lora(self.unet, lora_0, lora_1, 0) + img_noise_0 = self.ddim_inversion( + self.image2latent(img_0), text_embeddings_0) + if self.use_lora: + self.unet = load_lora(self.unet, lora_0, lora_1, 1) + img_noise_1 = self.ddim_inversion( + self.image2latent(img_1), text_embeddings_1) + + print("latents shape: ", img_noise_0.shape) + + with torch.no_grad(): + if self.use_reschedule: + alpha_scheduler = AlphaScheduler() + alpha_list = list(torch.linspace(0, 1, num_frames)) + images_pt = morph(alpha_list, progress, "Sampling...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1) + images_pt = [transforms.ToTensor()(img).unsqueeze(0) + for img in images_pt] + alpha_scheduler.from_imgs(images_pt) + alpha_list = alpha_scheduler.get_list() + print(alpha_list) + images_ = morph(alpha_list, progress, "Reschedule...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1) + else: + alpha_list = list(torch.linspace(0, 1, num_frames)) + print(alpha_list) + images_ = morph(alpha_list, progress, "Sampling...", img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1) + + if len(images) == 0: + images = images_ + else: + images += images_[1:] + + if save_intermediates: + for i, image in enumerate(images): + image.save(f"{self.output_path}/{i:02d}.png") + + return images diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4b626091f6a9a0d76dc47c1cc6ece7666dc0fb2b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,14 @@ +accelerate==0.23.0 +diffusers==0.17.1 +einops==0.7.0 +gradio==4.7.1 +numpy==1.26.1 +opencv_python==4.5.5.64 +packaging==23.2 +Pillow==10.1.0 +safetensors==0.4.0 +tqdm==4.65.0 +transformers==4.34.1 +torch +torchvision +lpips diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/alpha_scheduler.py b/utils/alpha_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..32d45e15fe4d30955e9d63c59955a8ccf1666011 --- /dev/null +++ b/utils/alpha_scheduler.py @@ -0,0 +1,54 @@ +import bisect +import torch +import torch.nn.functional as F +import lpips + +perceptual_loss = lpips.LPIPS() + + +def distance(img_a, img_b): + return perceptual_loss(img_a, img_b).item() + # return F.mse_loss(img_a, img_b).item() + + +class 
AlphaScheduler: + def __init__(self): + ... + + def from_imgs(self, imgs): + self.__num_values = len(imgs) + self.__values = [0] + for i in range(self.__num_values - 1): + dis = distance(imgs[i], imgs[i + 1]) + self.__values.append(dis) + self.__values[i + 1] += self.__values[i] + for i in range(self.__num_values): + self.__values[i] /= self.__values[-1] + + def save(self, filename): + torch.save(torch.tensor(self.__values), filename) + + def load(self, filename): + self.__values = torch.load(filename).tolist() + self.__num_values = len(self.__values) + + def get_x(self, y): + assert y >= 0 and y <= 1 + id = bisect.bisect_left(self.__values, y) + id -= 1 + if id < 0: + id = 0 + yl = self.__values[id] + yr = self.__values[id + 1] + xl = id * (1 / (self.__num_values - 1)) + xr = (id + 1) * (1 / (self.__num_values - 1)) + x = (y - yl) / (yr - yl) * (xr - xl) + xl + return x + + def get_list(self, len=None): + if len is None: + len = self.__num_values + + ys = torch.linspace(0, 1, len) + res = [self.get_x(y) for y in ys] + return res diff --git a/utils/lora_utils.py b/utils/lora_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e1fbd0240b0a728ef7117183ad20d66cc433f371 --- /dev/null +++ b/utils/lora_utils.py @@ -0,0 +1,283 @@ +from timeit import default_timer as timer +from datetime import timedelta +from PIL import Image +import os +import numpy as np +from einops import rearrange +import torch +import torch.nn.functional as F +from torchvision import transforms +import transformers +from accelerate import Accelerator +from accelerate.utils import set_seed +from packaging import version +from PIL import Image +import tqdm + +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DiffusionPipeline, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin +from diffusers.models.attention_processor import ( + AttnAddedKVProcessor, + AttnAddedKVProcessor2_0, + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + SlicedAttnAddedKVProcessor, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
+check_min_version("0.17.0") + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + +# model_path: path of the model +# image: input image, have not been pre-processed +# save_lora_dir: the path to save the lora +# prompt: the user input prompt +# lora_steps: number of lora training step +# lora_lr: learning rate of lora training +# lora_rank: the rank of lora +def train_lora(image, prompt, save_lora_dir, model_path=None, tokenizer=None, text_encoder=None, vae=None, unet=None, noise_scheduler=None, lora_steps=200, lora_lr=2e-4, lora_rank=16, weight_name=None, safe_serialization=False, progress=tqdm): + # initialize accelerator + accelerator = Accelerator( + gradient_accumulation_steps=1, + # mixed_precision='fp16' + ) + set_seed(0) + + # Load the tokenizer + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained( + model_path, + subfolder="tokenizer", + revision=None, + use_fast=False, + ) + # initialize the model + if noise_scheduler is None: + noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler") + if text_encoder is None: + text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None) + text_encoder = text_encoder_cls.from_pretrained( + model_path, subfolder="text_encoder", revision=None + ) + if vae is None: + vae = AutoencoderKL.from_pretrained( + model_path, subfolder="vae", revision=None + ) + if unet is None: + unet = UNet2DConditionModel.from_pretrained( + model_path, subfolder="unet", revision=None + ) + + # set device and dtype + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.requires_grad_(False) + + unet.to(device) + vae.to(device) + text_encoder.to(device) + + # initialize UNet LoRA + unet_lora_attn_procs = {} + for name, attn_processor in unet.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if 
name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = unet.config.block_out_channels[block_id] + else: + raise NotImplementedError("name must start with up_blocks, mid_blocks, or down_blocks") + + if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): + lora_attn_processor_class = LoRAAttnAddedKVProcessor + else: + lora_attn_processor_class = ( + LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor + ) + unet_lora_attn_procs[name] = lora_attn_processor_class( + hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank + ) + unet.set_attn_processor(unet_lora_attn_procs) + unet_lora_layers = AttnProcsLayers(unet.attn_processors) + + # Optimizer creation + params_to_optimize = (unet_lora_layers.parameters()) + optimizer = torch.optim.AdamW( + params_to_optimize, + lr=lora_lr, + betas=(0.9, 0.999), + weight_decay=1e-2, + eps=1e-08, + ) + + lr_scheduler = get_scheduler( + "constant", + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=lora_steps, + num_cycles=1, + power=1.0, + ) + + # prepare accelerator + unet_lora_layers = accelerator.prepare_model(unet_lora_layers) + optimizer = accelerator.prepare_optimizer(optimizer) + lr_scheduler = accelerator.prepare_scheduler(lr_scheduler) + + # initialize text embeddings + with torch.no_grad(): + text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None) + text_embedding = encode_prompt( + text_encoder, + text_inputs.input_ids, + text_inputs.attention_mask, + text_encoder_use_attention_mask=False + ) + + if type(image) == np.ndarray: + image = Image.fromarray(image) + + # initialize latent distribution + image_transforms = transforms.Compose( + [ + transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR), + # transforms.RandomCrop(512), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + image = image_transforms(image).to(device) + image = image.unsqueeze(dim=0) + + latents_dist = vae.encode(image).latent_dist + for _ in progress.tqdm(range(lora_steps), desc="Training LoRA..."): + unet.train() + model_input = latents_dist.sample() * vae.config.scaling_factor + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Predict the noise residual + model_pred = unet(noisy_model_input, timesteps, text_embedding).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = F.mse_loss(model_pred.float(), 
target.float(), reduction="mean") + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # save the trained lora + # unet = unet.to(torch.float32) + # vae = vae.to(torch.float32) + # text_encoder = text_encoder.to(torch.float32) + + # unwrap_model is used to remove all special modules added when doing distributed training + # so here, there is no need to call unwrap_model + # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + LoraLoaderMixin.save_lora_weights( + save_directory=save_lora_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=None, + weight_name=weight_name, + safe_serialization=safe_serialization + ) + +def load_lora(unet, lora_0, lora_1, alpha): + lora = {} + for key in lora_0: + lora[key] = (1 - alpha) * lora_0[key] + alpha * lora_1[key] + unet.load_attn_procs(lora) + return unet diff --git a/utils/model_utils.py b/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..48ae89ad39e86b26470ab2dc974fbd80c4543df4 --- /dev/null +++ b/utils/model_utils.py @@ -0,0 +1,86 @@ +import torch +import torch.nn.functional as F +from torchvision import transforms + +def calc_mean_std(feat, eps=1e-5): + # eps is a small value added to the variance to avoid divide-by-zero. + size = feat.size() + + N, C = size[:2] + feat_var = feat.view(N, C, -1).var(dim=2) + eps + if len(size) == 3: + feat_std = feat_var.sqrt().view(N, C, 1) + feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1) + else: + feat_std = feat_var.sqrt().view(N, C, 1, 1) + feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) + return feat_mean, feat_std + + +def get_img(img, resolution=512): + norm_mean = [0.5, 0.5, 0.5] + norm_std = [0.5, 0.5, 0.5] + transform = transforms.Compose([ + transforms.Resize((resolution, resolution)), + transforms.ToTensor(), + transforms.Normalize(norm_mean, norm_std) + ]) + img = transform(img) + return img.unsqueeze(0) + +@torch.no_grad() +def slerp(p0, p1, fract_mixing: float, adain=True): + r""" Copied from lunarring/latentblending + Helper function to correctly mix two random variables using spherical interpolation. + The function will always cast up to float64 for sake of extra 4. + Args: + p0: + First tensor for interpolation + p1: + Second tensor for interpolation + fract_mixing: float + Mixing coefficient of interval [0, 1]. + 0 will return in p0 + 1 will return in p1 + 0.x will return a mix between both preserving angular velocity. + """ + if p0.dtype == torch.float16: + recast_to = 'fp16' + else: + recast_to = 'fp32' + + p0 = p0.double() + p1 = p1.double() + + if adain: + mean1, std1 = calc_mean_std(p0) + mean2, std2 = calc_mean_std(p1) + mean = mean1 * (1 - fract_mixing) + mean2 * fract_mixing + std = std1 * (1 - fract_mixing) + std2 * fract_mixing + + norm = torch.linalg.norm(p0) * torch.linalg.norm(p1) + epsilon = 1e-7 + dot = torch.sum(p0 * p1) / norm + dot = dot.clamp(-1+epsilon, 1-epsilon) + + theta_0 = torch.arccos(dot) + sin_theta_0 = torch.sin(theta_0) + theta_t = theta_0 * fract_mixing + s0 = torch.sin(theta_0 - theta_t) / sin_theta_0 + s1 = torch.sin(theta_t) / sin_theta_0 + interp = p0*s0 + p1*s1 + + if adain: + interp = F.instance_norm(interp) * std + mean + + if recast_to == 'fp16': + interp = interp.half() + elif recast_to == 'fp32': + interp = interp.float() + + return interp + + +def do_replace_attn(key: str): + # return key.startswith('up_blocks.2') or key.startswith('up_blocks.3') + return key.startswith('up')
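+
+
+# ---------------------------------------------------------------------------
+# Minimal self-test sketch appended for illustration only; it is not part of
+# the original DiffMorpher pipeline. It shows how the `slerp` helper defined
+# above is expected to behave on Stable Diffusion-shaped noise latents of
+# shape (1, 4, 64, 64). The tensor shapes and alpha values below are
+# illustrative assumptions, not values taken from the paper or the pipeline.
+# ---------------------------------------------------------------------------
+if __name__ == "__main__":
+    torch.manual_seed(0)
+    # Two random tensors standing in for the inverted noise maps of two images.
+    noise_0 = torch.randn(1, 4, 64, 64)
+    noise_1 = torch.randn(1, 4, 64, 64)
+    for alpha in (0.0, 0.25, 0.5, 0.75, 1.0):
+        # adain=True additionally blends the per-channel mean/std of the two
+        # inputs, mirroring the use_adain option passed through the pipeline.
+        mixed = slerp(noise_0, noise_1, alpha, adain=True)
+        print(f"alpha={alpha:.2f} | norm={mixed.norm().item():.2f} "
+              f"| mean={mixed.mean().item():+.4f} "
+              f"| std={mixed.std().item():.4f}")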