Kevin committed on
Commit
6ee2eb6
·
1 Parent(s): 34b176d
Files changed (8) hide show
  1. .gitignore +4 -0
  2. Biden.jpg +0 -0
  3. Trump.jpg +0 -0
  4. alpha_scheduler.py +54 -0
  5. app.py +306 -0
  6. lora_utils.py +318 -0
  7. morph_attn.py +827 -0
  8. requirements.txt +14 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ lora/
2
+ __pycache__/
3
+ results/
4
+ core*
Biden.jpg ADDED
Trump.jpg ADDED
alpha_scheduler.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import bisect
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import lpips
5
+
6
+ perceptual_loss = lpips.LPIPS()
7
+
8
+
9
def distance(img_a, img_b):
    """Return the distance between two images as a plain Python number.

    Both images are passed unchanged to the module-level ``perceptual_loss``
    model (an ``lpips.LPIPS`` instance) and ``.item()`` is called on the
    result.
    """
    # Alternative metric kept for reference:
    # return F.mse_loss(img_a, img_b).item()
    score = perceptual_loss(img_a, img_b)
    return score.item()
12
+
13
+
14
class AlphaScheduler:
    """Invert the cumulative image-distance curve of a frame sequence.

    The scheduler stores one value per image: the cumulative distance from
    the first image, normalised so the first entry is ``0`` and the last is
    ``1``. :meth:`get_x` maps a target value ``y`` on that curve back to a
    position ``x`` in ``[0, 1]`` by piecewise-linear interpolation.
    """

    def __init__(self):
        ...

    def from_imgs(self, imgs):
        """Build the normalised cumulative-distance table from ``imgs``.

        Args:
            imgs: Indexable sequence of images accepted by ``distance``.
                At least two images are needed for :meth:`get_x` to work.
        """
        self.__num_values = len(imgs)
        cumulative = [0]
        for i in range(self.__num_values - 1):
            cumulative.append(cumulative[i] + distance(imgs[i], imgs[i + 1]))

        total = cumulative[-1]
        if total > 0:
            self.__values = [value / total for value in cumulative]
        elif self.__num_values > 1:
            # All images are identical (total distance 0): dividing would
            # fail, so fall back to a uniform schedule.
            last = self.__num_values - 1
            self.__values = [i / last for i in range(self.__num_values)]
        else:
            self.__values = cumulative

    def save(self, filename):
        """Write the value table to ``filename`` as a tensor via ``torch.save``."""
        torch.save(torch.tensor(self.__values), filename)

    def load(self, filename):
        """Read a value table previously written by :meth:`save`."""
        self.__values = torch.load(filename).tolist()
        self.__num_values = len(self.__values)

    def get_x(self, y):
        """Return the position ``x`` in ``[0, 1]`` whose curve value is ``y``.

        Args:
            y: Target value in ``[0, 1]`` (float or 0-dim tensor).

        Returns:
            The interpolated position, of the same numeric kind as ``y``.

        Raises:
            ValueError: If fewer than two values are stored.
        """
        assert y >= 0 and y <= 1
        if self.__num_values < 2:
            raise ValueError("AlphaScheduler needs at least two values to interpolate")

        # Index of the segment [values[idx], values[idx + 1]] containing y,
        # clamped so that idx + 1 is always a valid index.
        idx = bisect.bisect_left(self.__values, y) - 1
        idx = max(0, min(idx, self.__num_values - 2))

        yl = self.__values[idx]
        yr = self.__values[idx + 1]
        step = 1 / (self.__num_values - 1)
        xl = idx * step
        xr = (idx + 1) * step
        if yr == yl:
            # Flat segment (only reachable for y == 0 when the first
            # distances are 0): avoid 0/0 and return the segment start.
            # ``0 * y`` keeps the result the same kind (float/tensor) as y.
            return xl + 0 * y
        return (y - yl) / (yr - yl) * (xr - xl) + xl

    def get_list(self, len=None):
        """Return ``get_x`` evaluated at ``len`` evenly spaced values in ``[0, 1]``.

        Args:
            len: Number of samples; defaults to the number of stored values.
        """
        if len is None:
            len = self.__num_values

        ys = torch.linspace(0, 1, len)
        return [self.get_x(y) for y in ys]
app.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import cv2
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from datetime import datetime
8
+ from morph_attn import DiffMorpherPipeline
9
+ from lora_utils import train_lora
10
+
11
+ LENGTH=480
12
+
13
def train_lora_interface(
    image,
    prompt,
    model_path,
    output_path,
    lora_steps,
    lora_rank,
    lora_lr,
    num
):
    """Train one LoRA from the UI and return a status message.

    The weights are written to ``output_path`` as ``lora_{num}.ckpt``; the
    returned message names LoRA "A" when ``num`` is 0 and "B" otherwise.
    """
    os.makedirs(output_path, exist_ok=True)
    training_options = dict(
        lora_steps=lora_steps,
        lora_lr=lora_lr,
        lora_rank=lora_rank,
        weight_name=f"lora_{num}.ckpt",
        progress=gr.Progress(),
    )
    train_lora(image, prompt, output_path, model_path, **training_options)
    label = 'A' if num == 0 else 'B'
    return f"Train LoRA {label} Done!"
27
+
28
def run_diffmorpher(
    image_0,
    image_1,
    prompt_0,
    prompt_1,
    model_path,
    lora_mode,
    lamb,
    use_adain,
    use_reschedule,
    num_frames,
    fps,
    load_lora_path_0,
    load_lora_path_1,
    output_path
):
    """Run the morphing pipeline and encode the frames as an MP4.

    Args:
        image_0, image_1: The two input images.
        prompt_0, prompt_1: Prompts for the two images.
        model_path: Diffusion model passed to
            ``DiffMorpherPipeline.from_pretrained``.
        lora_mode: Dropdown label; "Fix LoRA A"/"Fix LoRA B" pin LoRA 0/1,
            any other value leaves ``fix_lora`` as ``None``.
        lamb, use_adain, use_reschedule, num_frames: Forwarded to the
            pipeline unchanged.
        fps: Frame rate of the output video.
        load_lora_path_0, load_lora_path_1: LoRA checkpoints; when empty,
            ``{output_path}/lora_0.ckpt`` / ``lora_1.ckpt`` are used.
        output_path: Directory the video is written to (created if missing).

    Returns:
        The update for the ``output_video`` component pointing at the
        written file.
    """
    now = datetime.now()
    run_id = now.strftime("%H%M") + "_" + now.strftime("%Y%m%d")
    os.makedirs(output_path, exist_ok=True)
    morpher_pipeline = DiffMorpherPipeline.from_pretrained(model_path, torch_dtype=torch.float32).to("cuda")

    # The UI dropdown offers "Fix LoRA A"/"Fix LoRA B"; the numeric labels
    # are still accepted so existing callers keep working.
    fix_lora_by_label = {
        "Fix LoRA A": 0,
        "Fix LoRA 0": 0,
        "Fix LoRA B": 1,
        "Fix LoRA 1": 1,
    }
    fix_lora = fix_lora_by_label.get(lora_mode)

    if not load_lora_path_0:
        load_lora_path_0 = f"{output_path}/lora_0.ckpt"
    if not load_lora_path_1:
        load_lora_path_1 = f"{output_path}/lora_1.ckpt"
    images = morpher_pipeline(
        img_0=image_0,
        img_1=image_1,
        prompt_0=prompt_0,
        prompt_1=prompt_1,
        load_lora_path_0=load_lora_path_0,
        load_lora_path_1=load_lora_path_1,
        lamb=lamb,
        use_adain=use_adain,
        use_reschedule=use_reschedule,
        num_frames=num_frames,
        fix_lora=fix_lora,
        progress=gr.Progress()
    )

    frames = [cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) for image in images]
    # OpenCV expects every written frame to match the size given to the
    # writer, so take it from the generated frames instead of assuming
    # LENGTH x LENGTH (inputs are resized to 512x512, LENGTH is 480).
    if frames:
        height, width = frames[0].shape[:2]
    else:
        height = width = LENGTH

    video_path = f"{output_path}/{run_id}.mp4"
    video = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    try:
        for frame in frames:
            video.write(frame)
    finally:
        # Always finalise the file, even if a frame fails to encode.
        video.release()
    cv2.destroyAllWindows()
    return output_video.update(value=video_path)
78
+
79
def run_all(
    image_0,
    image_1,
    prompt_0,
    prompt_1,
    model_path,
    lora_mode,
    lamb,
    use_adain,
    use_reschedule,
    num_frames,
    fps,
    load_lora_path_0,
    load_lora_path_1,
    output_path,
    lora_steps,
    lora_rank,
    lora_lr
):
    """Train both LoRAs into ``output_path`` and then run the morphing step.

    LoRA 0 is trained on ``image_0``/``prompt_0`` and LoRA 1 on
    ``image_1``/``prompt_1``; the result of ``run_diffmorpher`` is returned.
    """
    os.makedirs(output_path, exist_ok=True)

    training_jobs = (
        (image_0, prompt_0, "lora_0.ckpt"),
        (image_1, prompt_1, "lora_1.ckpt"),
    )
    for image, prompt, weight_name in training_jobs:
        train_lora(
            image, prompt, output_path, model_path,
            lora_steps=lora_steps, lora_lr=lora_lr, lora_rank=lora_rank,
            weight_name=weight_name, progress=gr.Progress(),
        )

    return run_diffmorpher(
        image_0=image_0,
        image_1=image_1,
        prompt_0=prompt_0,
        prompt_1=prompt_1,
        model_path=model_path,
        lora_mode=lora_mode,
        lamb=lamb,
        use_adain=use_adain,
        use_reschedule=use_reschedule,
        num_frames=num_frames,
        fps=fps,
        load_lora_path_0=load_lora_path_0,
        load_lora_path_1=load_lora_path_1,
        output_path=output_path,
    )
119
+
120
+ with gr.Blocks() as demo:
121
+
122
+ with gr.Row():
123
+ gr.Markdown("""
124
+ # Official Implementation of [DiffMorpher](https://kevin-thu.github.io/DiffMorpher_page/)
125
+ """)
126
+
127
+ original_image_0, original_image_1 = gr.State(Image.open("Trump.jpg").convert("RGB").resize((512,512), Image.BILINEAR)), gr.State(Image.open("Biden.jpg").convert("RGB").resize((512,512), Image.BILINEAR))
128
+ # key_points_0, key_points_1 = gr.State([]), gr.State([])
129
+ # to_change_points = gr.State([])
130
+
131
+ with gr.Row():
132
+ with gr.Column():
133
+ input_img_0 = gr.Image(type="numpy", label="Input image A", value="Trump.jpg", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
134
+ prompt_0 = gr.Textbox(label="Prompt for image A", value="a photo of an American man", interactive=True)
135
+ with gr.Row():
136
+ train_lora_0_button = gr.Button("Train LoRA A")
137
+ train_lora_1_button = gr.Button("Train LoRA B")
138
+ # show_correspond_button = gr.Button("Show correspondence points")
139
+ with gr.Column():
140
+ input_img_1 = gr.Image(type="numpy", label="Input image B ", value="Biden.jpg", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
141
+ prompt_1 = gr.Textbox(label="Prompt for image B", value="a photo of an American man", interactive=True)
142
+ with gr.Row():
143
+ clear_button = gr.Button("Clear All")
144
+ run_button = gr.Button("Run w/o LoRA training")
145
+ with gr.Column():
146
+ output_video = gr.Video(format="mp4", label="Output video", show_label=True, height=LENGTH, width=LENGTH, interactive=False)
147
+ lora_progress_bar = gr.Textbox(label="Display LoRA training progress", interactive=False)
148
+ run_all_button = gr.Button("Run!")
149
+ # with gr.Column():
150
+ # output_video = gr.Video(label="Output video", show_label=True, height=LENGTH, width=LENGTH)
151
+
152
+ with gr.Row():
153
+ gr.Markdown("""
154
+ ### Usage:
155
+ 1. Upload two images (with correspondence) and fill out the prompts.
156
+ 2. Click **"Run!"**
157
+
158
+ Or:
159
+ 1. Upload two images (with correspondence) and fill out the prompts.
160
+ 2. Click the **"Train LoRA A/B"** button to fit two LoRAs for two images respectively. <br> &nbsp;&nbsp;
161
+ If you have trained LoRA A or LoRA B before, you can skip the step and fill the specific LoRA path in LoRA settings. <br> &nbsp;&nbsp;
162
+ Trained LoRAs are saved to `[Output Path]/lora_0.ckpt` and `[Output Path]/lora_1.ckpt` by default.
163
+ 3. You might also change the settings below.
164
+ 4. Click **"Run w/o LoRA training"**
165
+
166
+ ### Note:
167
+ 1. To speed up the generation process, you can **ruduce the number of frames** or **turn off "Use Reschedule"** ("Use Reschedule" will double the generation time).
168
+ 2. You can try the influence of different prompts. It seems that using the same prompts or aligned prompts works better.
169
+ ### Have fun!
170
+ """)
171
+
172
+ with gr.Accordion(label="Algorithm Parameters"):
173
+ with gr.Tab("Basic Settings"):
174
+ with gr.Row():
175
+ # local_models_dir = 'local_pretrained_models'
176
+ # local_models_choice = \
177
+ # [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
178
+ model_path = gr.Text(value="stabilityai/stable-diffusion-2-1-base",
179
+ label="Diffusion Model Path", interactive=True
180
+ )
181
+ lamb = gr.Slider(value=0.6, minimum=0, maximum=1, step=0.1, label="Lambda for attention replacement", interactive=True)
182
+ lora_mode = gr.Dropdown(value="LoRA Interp",
183
+ label="LoRA Interp. or Fix LoRA",
184
+ choices=["LoRA Interp", "Fix LoRA A", "Fix LoRA B"],
185
+ interactive=True
186
+ )
187
+ use_adain = gr.Checkbox(value=True, label="Use AdaIN", interactive=True)
188
+ use_reschedule = gr.Checkbox(value=True, label="Use Reschedule", interactive=True)
189
+ with gr.Row():
190
+ num_frames = gr.Number(value=15, minimum=0, label="Number of Frames", precision=0, interactive=True)
191
+ fps = gr.Number(value=8, minimum=0, label="FPS (Frame rate)", precision=0, interactive=True)
192
+ output_path = gr.Text(value="./results", label="Output Path", interactive=True)
193
+
194
+ with gr.Tab("LoRA Settings"):
195
+ with gr.Row():
196
+ lora_steps = gr.Number(value=200, label="LoRA training steps", precision=0, interactive=True)
197
+ lora_lr = gr.Number(value=0.0002, label="LoRA learning rate", interactive=True)
198
+ lora_rank = gr.Number(value=16, label="LoRA rank", precision=0, interactive=True)
199
+ # save_lora_dir = gr.Text(value="./lora", label="LoRA model save path", interactive=True)
200
+ load_lora_path_0 = gr.Text(value="", label="LoRA model load path for image A", interactive=True)
201
+ load_lora_path_1 = gr.Text(value="", label="LoRA model load path for image B", interactive=True)
202
+
203
def store_img(img):
    """Turn an uploaded array into a 512x512 RGB PIL image."""
    pil_image = Image.fromarray(img)
    rgb_image = pil_image.convert("RGB")
    # Resize the input to 512x512 with bilinear resampling.
    return rgb_image.resize((512, 512), Image.BILINEAR)
210
+ input_img_0.upload(
211
+ store_img,
212
+ [input_img_0],
213
+ [original_image_0]
214
+ )
215
+ input_img_1.upload(
216
+ store_img,
217
+ [input_img_1],
218
+ [original_image_1]
219
+ )
220
+
221
def clear(LENGTH):
    """Return six reset values: two emptied image updates followed by four ``None``s.

    Each image update clears the value and sets width and height to
    ``LENGTH``.
    """
    blank_first = gr.Image.update(value=None, width=LENGTH, height=LENGTH)
    blank_second = gr.Image.update(value=None, width=LENGTH, height=LENGTH)
    return blank_first, blank_second, None, None, None, None
225
+ clear_button.click(
226
+ clear,
227
+ [gr.Number(value=LENGTH, visible=False, precision=0)],
228
+ [input_img_0, input_img_1, original_image_0, original_image_1, prompt_0, prompt_1]
229
+ )
230
+
231
+ train_lora_0_button.click(
232
+ train_lora_interface,
233
+ [
234
+ original_image_0,
235
+ prompt_0,
236
+ model_path,
237
+ output_path,
238
+ lora_steps,
239
+ lora_rank,
240
+ lora_lr,
241
+ gr.Number(value=0, visible=False, precision=0)
242
+ ],
243
+ [lora_progress_bar]
244
+ )
245
+
246
+ train_lora_1_button.click(
247
+ train_lora_interface,
248
+ [
249
+ original_image_1,
250
+ prompt_1,
251
+ model_path,
252
+ output_path,
253
+ lora_steps,
254
+ lora_rank,
255
+ lora_lr,
256
+ gr.Number(value=1, visible=False, precision=0)
257
+ ],
258
+ [lora_progress_bar]
259
+ )
260
+
261
+ run_button.click(
262
+ run_diffmorpher,
263
+ [
264
+ original_image_0,
265
+ original_image_1,
266
+ prompt_0,
267
+ prompt_1,
268
+ model_path,
269
+ lora_mode,
270
+ lamb,
271
+ use_adain,
272
+ use_reschedule,
273
+ num_frames,
274
+ fps,
275
+ load_lora_path_0,
276
+ load_lora_path_1,
277
+ output_path
278
+ ],
279
+ [output_video]
280
+ )
281
+
282
+ run_all_button.click(
283
+ run_all,
284
+ [
285
+ original_image_0,
286
+ original_image_1,
287
+ prompt_0,
288
+ prompt_1,
289
+ model_path,
290
+ lora_mode,
291
+ lamb,
292
+ use_adain,
293
+ use_reschedule,
294
+ num_frames,
295
+ fps,
296
+ load_lora_path_0,
297
+ load_lora_path_1,
298
+ output_path,
299
+ lora_steps,
300
+ lora_rank,
301
+ lora_lr
302
+ ],
303
+ [output_video]
304
+ )
305
+
306
+ demo.queue().launch(debug=True)
lora_utils.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from timeit import default_timer as timer
2
+ from datetime import timedelta
3
+ from PIL import Image
4
+ import os
5
+ import numpy as np
6
+ from einops import rearrange
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torchvision import transforms
10
+ import transformers
11
+ from accelerate import Accelerator
12
+ from accelerate.utils import set_seed
13
+ from packaging import version
14
+ from PIL import Image
15
+ import tqdm
16
+
17
+ from transformers import AutoTokenizer, PretrainedConfig
18
+
19
+ import diffusers
20
+ from diffusers import (
21
+ AutoencoderKL,
22
+ DDPMScheduler,
23
+ DiffusionPipeline,
24
+ DPMSolverMultistepScheduler,
25
+ StableDiffusionPipeline,
26
+ UNet2DConditionModel,
27
+ )
28
+ from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
29
+ from diffusers.models.attention_processor import (
30
+ AttnAddedKVProcessor,
31
+ AttnAddedKVProcessor2_0,
32
+ LoRAAttnAddedKVProcessor,
33
+ LoRAAttnProcessor,
34
+ LoRAAttnProcessor2_0,
35
+ SlicedAttnAddedKVProcessor,
36
+ )
37
+ from diffusers.optimization import get_scheduler
38
+ from diffusers.utils import check_min_version
39
+ from diffusers.utils.import_utils import is_xformers_available
40
+
41
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
42
+ check_min_version("0.17.0")
43
+
44
+
45
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Return the text-encoder class named in the model's ``text_encoder`` config.

    The first entry of the config's ``architectures`` list selects the class;
    each class is imported only when it is the one requested.

    Raises:
        ValueError: If the architecture is not one of the three supported ones.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    architecture = config.architectures[0]

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel

    if architecture == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation

    if architecture == "T5EncoderModel":
        from transformers import T5EncoderModel

        return T5EncoderModel

    raise ValueError(f"{architecture} is not supported.")
67
+
68
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    """Tokenize ``prompt``, padded and truncated to a fixed length.

    Args:
        tokenizer: Callable tokenizer exposing ``model_max_length``.
        prompt: Text to tokenize.
        tokenizer_max_length: Length to pad/truncate to; the tokenizer's
            ``model_max_length`` is used when this is ``None``.

    Returns:
        Whatever the tokenizer returns for ``return_tensors="pt"``.
    """
    limit = tokenizer.model_max_length if tokenizer_max_length is None else tokenizer_max_length
    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=limit,
        return_tensors="pt",
    )
83
+
84
def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=False):
    """Run the text encoder on ``input_ids`` and return its first output.

    Args:
        text_encoder: Callable encoder exposing a ``device`` attribute.
        input_ids: Token ids; moved to the encoder's device.
        attention_mask: Mask for the ids; only moved and passed on when
            ``text_encoder_use_attention_mask`` is true, otherwise ``None``
            is passed to the encoder.
        text_encoder_use_attention_mask: Whether to forward the mask.
    """
    ids_on_device = input_ids.to(text_encoder.device)
    mask = attention_mask.to(text_encoder.device) if text_encoder_use_attention_mask else None
    encoder_output = text_encoder(ids_on_device, attention_mask=mask)
    return encoder_output[0]
99
+
100
def train_lora(image, prompt, save_lora_dir, model_path=None, tokenizer=None, text_encoder=None, vae=None, unet=None, noise_scheduler=None, lora_steps=200, lora_lr=2e-4, lora_rank=16, weight_name=None, safe_serialization=False, progress=tqdm):
    """Fit LoRA attention layers of the UNet to a single image and save them.

    Args:
        image: Input image that has not been pre-processed (PIL image or
            numpy array).
        prompt: The user input prompt describing the image.
        save_lora_dir: Directory the LoRA weights are saved to.
        model_path: Path of the pretrained model; used to load every
            component below that is passed as ``None``.
        tokenizer, text_encoder, vae, unet, noise_scheduler: Optional
            pre-loaded components.
        lora_steps: Number of LoRA training steps.
        lora_lr: Learning rate of LoRA training.
        lora_rank: The rank of the LoRA.
        weight_name: File name forwarded to ``save_lora_weights``.
        safe_serialization: Forwarded to ``save_lora_weights``.
        progress: Object exposing ``tqdm(iterable, desc=...)``; the ``tqdm``
            module by default.

    Raises:
        NotImplementedError: If an attention processor name has an
            unexpected prefix.
        ValueError: If the scheduler's prediction type is not ``epsilon``
            or ``v_prediction``.
    """
    # initialize accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        # mixed_precision='fp16'
    )
    # Seed the RNGs with a fixed value.
    set_seed(0)

    # Load the tokenizer
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            subfolder="tokenizer",
            revision=None,
            use_fast=False,
        )
    # initialize the model
    if noise_scheduler is None:
        noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")
    if text_encoder is None:
        text_encoder_cls = import_model_class_from_model_name_or_path(model_path, revision=None)
        text_encoder = text_encoder_cls.from_pretrained(
            model_path, subfolder="text_encoder", revision=None
        )
    if vae is None:
        vae = AutoencoderKL.from_pretrained(
            model_path, subfolder="vae", revision=None
        )
    if unet is None:
        unet = UNet2DConditionModel.from_pretrained(
            model_path, subfolder="unet", revision=None
        )

    # set device and dtype
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # The base models stay frozen; only the LoRA layers added below are trained.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    unet.to(device)
    vae.to(device)
    text_encoder.to(device)

    # initialize UNet LoRA
    unet_lora_attn_procs = {}
    for name, attn_processor in unet.attn_processors.items():
        # Processors named "...attn1.processor" get no cross-attention dim.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        else:
            raise NotImplementedError(
                f"unexpected attention processor name {name!r}: "
                "name must start with up_blocks, mid_block, or down_blocks"
            )

        if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)):
            lora_attn_processor_class = LoRAAttnAddedKVProcessor
        else:
            lora_attn_processor_class = (
                LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
            )
        unet_lora_attn_procs[name] = lora_attn_processor_class(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
        )
    unet.set_attn_processor(unet_lora_attn_procs)
    unet_lora_layers = AttnProcsLayers(unet.attn_processors)

    # Optimizer creation
    params_to_optimize = (unet_lora_layers.parameters())
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=lora_lr,
        betas=(0.9, 0.999),
        weight_decay=1e-2,
        eps=1e-08,
    )

    lr_scheduler = get_scheduler(
        "constant",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=lora_steps,
        num_cycles=1,
        power=1.0,
    )

    # prepare accelerator
    unet_lora_layers = accelerator.prepare_model(unet_lora_layers)
    optimizer = accelerator.prepare_optimizer(optimizer)
    lr_scheduler = accelerator.prepare_scheduler(lr_scheduler)

    # initialize text embeddings
    with torch.no_grad():
        text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None)
        text_embedding = encode_prompt(
            text_encoder,
            text_inputs.input_ids,
            text_inputs.attention_mask,
            text_encoder_use_attention_mask=False
        )

    # isinstance also covers ndarray subclasses, which type(...) == missed.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # initialize latent distribution
    image_transforms = transforms.Compose(
        [
            transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
            # transforms.RandomCrop(512),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    image = image_transforms(image).to(device)
    image = image.unsqueeze(dim=0)

    latents_dist = vae.encode(image).latent_dist

    # Training mode is loop-invariant, so it is set once before the loop.
    unet.train()
    for _ in progress.tqdm(range(lora_steps), desc="Training LoRA..."):
        model_input = latents_dist.sample() * vae.config.scaling_factor
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(model_input)
        bsz = model_input.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
        )
        timesteps = timesteps.long()

        # Add noise to the model input according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

        # Predict the noise residual
        model_pred = unet(noisy_model_input, timesteps, text_embedding).sample

        # Get the target for loss depending on the prediction type
        if noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif noise_scheduler.config.prediction_type == "v_prediction":
            target = noise_scheduler.get_velocity(model_input, noise, timesteps)
        else:
            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    # save the trained lora
    # unet = unet.to(torch.float32)
    # vae = vae.to(torch.float32)
    # text_encoder = text_encoder.to(torch.float32)

    # unwrap_model is used to remove all special modules added when doing distributed training
    # so here, there is no need to call unwrap_model
    # unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
    LoraLoaderMixin.save_lora_weights(
        save_directory=save_lora_dir,
        unet_lora_layers=unet_lora_layers,
        text_encoder_lora_layers=None,
        weight_name=weight_name,
        safe_serialization=safe_serialization
    )
277
+
278
def load_lora(unet, lora_0, lora_1, alpha):
    """Load a linear blend of two LoRA weight dicts into ``unet``.

    For every key of ``lora_0`` the blended weight is
    ``(1 - alpha) * lora_0[key] + alpha * lora_1[key]``. The blend is passed
    to ``unet.load_attn_procs`` and the same ``unet`` object is returned.
    """
    blended = {
        name: (1 - alpha) * lora_0[name] + alpha * lora_1[name]
        for name in lora_0
    }
    unet.load_attn_procs(blended)
    return unet
284
+
285
+ # import safetensors
286
+ # unet = UNet2DConditionModel.from_pretrained(
287
+ # "stabilityai/stable-diffusion-2-1-base", subfolder="unet", revision=None
288
+ # )
289
+ # lora = safetensors.torch.load_file("../models/lora/majicmixRealistic_betterV2V25.safetensors", device="cuda")
290
+ # unet = safetensors.torch.load_file("../stabilityai/stable-diffusion-1-5/v1-5-pruned-emaonly.safetensors", device="cuda")
291
+ # with open("lora.txt", "w") as f:
292
+ # for key in lora:
293
+ # f.write(f"{key} {lora[key].shape}\n")
294
+ # with open("unet.txt", "w") as f:
295
+ # for key in unet:
296
+ # f.write(f"{key} {unet[key].shape}\n")
297
+ # unet.load_attn_procs(lora)
298
+
299
+ # lora_path = "models/lora"
300
+ # image_path_1 = "input/sculpture.jpg"
301
+ # # image_path_0 = "input/realdog0.jpg"
302
+
303
+ # prompt = "a photo of a sculpture"
304
+ # train_lora(Image.open(image_path_1), prompt, lora_path, "stabilityai/stable-diffusion-1-5", weight_name="sculpture_v15.safetensors", safe_serialization=True)
305
+ # train_lora(image_path_0, prompt, "stabilityai/stable-diffusion-2-1-base", lora_path, weight_name="realdog0.ckpt")
306
+ # realdog1_lora = torch.load(os.path.join(lora_path, "realdog1.ckpt"))
307
+ # realdog0_lora = torch.load(os.path.join(lora_path, "realdog0.ckpt"))
308
+
309
+ # pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float32)
310
+ # pipe.to("cuda")
311
+
312
+ # for t in torch.linspace(0, 1, 10):
313
+ # lora = {}
314
+ # for key in realdog0_lora:
315
+ # lora[key] = (1 - t) * realdog1_lora[key] + t * realdog0_lora[key]
316
+ # pipe.unet.load_attn_procs(lora)
317
+ # image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
318
+ # image.save(f"test/lora_interp/{t}.jpg")
morph_attn.py ADDED
@@ -0,0 +1,827 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
3
+ from diffusers.models.attention_processor import AttnProcessor
4
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
5
+ from diffusers.schedulers import KarrasDiffusionSchedulers
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import tqdm
9
+ import numpy as np
10
+ import safetensors
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
14
+ from lora_utils import train_lora, load_lora
15
+ from diffusers import StableDiffusionPipeline
16
+ from argparse import ArgumentParser
17
+ from alpha_scheduler import AlphaScheduler
18
+
19
# Command-line options of this module.
parser = ArgumentParser()
# Inputs: the two images and their prompts.
parser.add_argument(
    '--image_path_0', type=str, default='',
    help='Path of the image to be processed (default: %(default)s)')
parser.add_argument(
    '--prompt_0', type=str, default='',
    help='Prompt of the image (default: %(default)s)')
parser.add_argument(
    '--image_path_1', type=str, default='',
    help='Path of the 2nd image to be processed, used in "morphing" mode (default: %(default)s)')
parser.add_argument(
    '--prompt_1', type=str, default='',
    help='Prompt of the 2nd image, used in "morphing" mode (default: %(default)s)')
# Output location and animation settings.
parser.add_argument(
    '--output_path', type=str, default='',
    help='Path of the output image (default: %(default)s)'
)
parser.add_argument(
    '--num_frames', type=int, default=50,
    help='Number of frames to generate (default: %(default)s)'
)
parser.add_argument(
    '--duration', type=int, default=50,
    help='Duration of each frame (default: %(default)s)'
)
# Algorithm switches.
parser.add_argument(
    '--use_lora', action='store_true',
    help='Use LORA to generate images (default: False)'
)
parser.add_argument(
    '--guidance_scale', type=float, default=1.,
    help='CFG guidace (default: %(default)s)'
)
parser.add_argument(
    '--attn_beta', type=float, default=None,
)
# NOTE(review): declared with a single leading dash, unlike the other
# options; it is still stored as ``args.reschedule``.
parser.add_argument(
    '-reschedule', action='store_true',
)
parser.add_argument(
    '--lamd', type=float, default=0.6,
)
parser.add_argument(
    '--use_adain', action='store_true'
)

# NOTE(review): this runs at module level, so the process command line is
# parsed whenever this module is imported, not only when run as a script.
args = parser.parse_args()
66
+ # name = args.output_path.split('/')[-1]
67
+ # attn_beta = args.attn_beta
68
+ # num_frames = args.num_frames
69
+ # use_alpha_scheduler = args.reschedule
70
+ # attn_step = 50 * args.lamd
71
+
72
+
73
def calc_mean_std(feat, eps=1e-5):
    """Return the per-(sample, channel) mean and standard deviation of ``feat``.

    Statistics are taken over everything after the first two dimensions.
    The outputs have shape ``(N, C, 1)`` for a 3-D input and ``(N, C, 1, 1)``
    otherwise.

    Args:
        feat: Tensor whose first two dimensions are ``N`` and ``C``.
        eps: Small value added to the variance to avoid divide-by-zero.
    """
    shape = feat.size()
    N, C = shape[:2]
    flat = feat.view(N, C, -1)
    stat_shape = (N, C, 1) if len(shape) == 3 else (N, C, 1, 1)

    feat_std = (flat.var(dim=2) + eps).sqrt().view(*stat_shape)
    feat_mean = flat.mean(dim=2).view(*stat_shape)
    return feat_mean, feat_std
86
+
87
+
88
def get_img(img, resolution=512):
    """Resize, tensorize and normalise ``img``, returning it with a leading batch dim.

    The image is resized to ``resolution`` x ``resolution``, converted with
    ``ToTensor`` and normalised with mean 0.5 and std 0.5 per channel.
    """
    pipeline = transforms.Compose([
        transforms.Resize((resolution, resolution)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    return pipeline(img).unsqueeze(0)
98
+
99
@torch.no_grad()
def slerp(p0, p1, fract_mixing: float, adain=True):
    r""" Copied from lunarring/latentblending
    Spherical interpolation between two tensors.

    The computation is carried out in float64. The result is cast to
    float16 when ``p0`` was float16 and to float32 otherwise.
    Args:
        p0:
            First tensor for interpolation
        p1:
            Second tensor for interpolation
        fract_mixing: float
            Mixing coefficient of interval [0, 1].
            0 will return in p0
            1 will return in p1
            0.x will return a mix between both preserving angular velocity.
        adain:
            When true, the interpolated tensor is instance-normalised and
            rescaled to the ``fract_mixing``-weighted mean/std of the inputs.
    """
    wants_half = p0.dtype == torch.float16

    p0 = p0.double()
    p1 = p1.double()

    if adain:
        mean_0, std_0 = calc_mean_std(p0)
        mean_1, std_1 = calc_mean_std(p1)
        target_mean = mean_0 * (1 - fract_mixing) + mean_1 * fract_mixing
        target_std = std_0 * (1 - fract_mixing) + std_1 * fract_mixing

    # Cosine of the angle between the inputs, clamped away from +-1 so that
    # sin(theta_0) below cannot be zero.
    epsilon = 1e-7
    norm_product = torch.linalg.norm(p0) * torch.linalg.norm(p1)
    cos_theta = (torch.sum(p0 * p1) / norm_product).clamp(-1+epsilon, 1-epsilon)

    theta_0 = torch.arccos(cos_theta)
    theta_t = theta_0 * fract_mixing
    sin_theta_0 = torch.sin(theta_0)
    weight_0 = torch.sin(theta_0 - theta_t) / sin_theta_0
    weight_1 = torch.sin(theta_t) / sin_theta_0
    blended = p0*weight_0 + p1*weight_1

    if adain:
        blended = F.instance_norm(blended) * target_std + target_mean

    return blended.half() if wants_half else blended.float()
150
+
151
+
152
def do_replace_attn(key: str):
    """Tell whether the attention processor registered under ``key`` is wrapped.

    Every key that begins with ``'up'`` qualifies (presumably the UNet's
    ``up_blocks.*`` processors - confirm against the model's key names).
    """
    prefix = 'up'
    return key[:len(prefix)] == prefix
155
+
156
+
157
class StoreProcessor:
    """Attention-processor wrapper that records self-attention inputs.

    Each self-attention call (``encoder_hidden_states is None``) stores a
    detached reference to ``hidden_states`` in ``value_dict[name]`` under a
    running call index, then delegates to the wrapped processor unchanged.
    Cross-attention calls are delegated without recording anything.
    """

    def __init__(self, original_processor, value_dict, name):
        self.original_processor = original_processor
        self.name = name
        # Running index of recorded self-attention calls.
        self.id = 0
        self.value_dict = value_dict
        # Start from an empty record for this processor name.
        value_dict[name] = {}

    def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
        is_self_attention = encoder_hidden_states is None
        if is_self_attention:
            recorded = self.value_dict[self.name]
            recorded[self.id] = hidden_states.detach()
            self.id += 1

        return self.original_processor(
            attn,
            hidden_states,
            *args,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )
176
+
177
+
178
class LoadProcessor:
    """Attention-processor wrapper that injects blended self-attention context.

    For the first ``lamb`` fraction of the recorded self-attention calls, the
    key/value source of self-attention is replaced by a blend of the maps
    recorded for the two endpoint images (``img0_dict`` / ``img1_dict``,
    indexed by processor ``name`` and call index) and the current hidden
    states. All other calls are delegated to the wrapped processor.

    Args:
        original_processor: Wrapped processor. ``parent_call`` reads its
            ``to_q_lora`` / ``to_k_lora`` / ``to_v_lora`` / ``to_out_lora``
            attributes, so it must provide them.
        name: Key of this processor inside ``img0_dict`` / ``img1_dict``.
        img0_dict, img1_dict: ``{name: {call_index: hidden_states}}`` records
            (the layout written by ``StoreProcessor``).
        alpha: Blend weight between the two recorded maps (0 -> img0, 1 -> img1).
        beta: Weight of the current hidden states in the blended context.
        lamb: Fraction of the recorded calls during which replacement is active.
    """

    def __init__(self, original_processor, name, img0_dict, img1_dict, alpha, beta=0, lamb=0.6):
        self.original_processor = original_processor
        self.name = name
        self.img0_dict = img0_dict
        self.img1_dict = img1_dict
        self.alpha = alpha
        self.beta = beta
        self.lamb = lamb
        # Index of the next self-attention call; wraps after one full pass.
        self.id = 0

    def parent_call(
        self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        """Run attention with the LoRA layers of the wrapped processor.

        Queries come from ``hidden_states``; keys and values come from
        ``encoder_hidden_states`` (or from ``hidden_states`` when it is None).
        Each projection adds ``scale`` times the matching ``*_lora`` layer of
        ``self.original_processor``.
        """
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the two trailing axes into one sequence axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states) + scale * \
            self.original_processor.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(
                encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states) + scale * \
            self.original_processor.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * \
            self.original_processor.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(
            query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](
            hidden_states) + scale * self.original_processor.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(
                -1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def __call__(self, attn, hidden_states, *args, encoder_hidden_states=None, attention_mask=None, **kwargs):
        # Cross-attention is never modified.
        if encoder_hidden_states is not None:
            return self.original_processor(attn, hidden_states, *args,
                                           encoder_hidden_states=encoder_hidden_states,
                                           attention_mask=attention_mask,
                                           **kwargs)

        # One map is recorded per self-attention call of the store pass, so
        # the record length is the number of calls in one sampling run. Using
        # it (instead of a hard-coded 50) keeps ``lamb`` a true fraction for
        # any number of inference steps.
        num_recorded = len(self.img0_dict[self.name])

        if self.id < num_recorded * self.lamb:
            map0 = self.img0_dict[self.name][self.id]
            map1 = self.img1_dict[self.name][self.id]
            # Linear blend of the two recorded maps, optionally mixed with the
            # current hidden states via beta.
            cross_map = self.beta * hidden_states + \
                (1 - self.beta) * ((1 - self.alpha) * map0 + self.alpha * map1)
            res = self.parent_call(attn, hidden_states, *args,
                                   encoder_hidden_states=cross_map,
                                   attention_mask=attention_mask,
                                   **kwargs)
        else:
            res = self.original_processor(attn, hidden_states, *args,
                                          encoder_hidden_states=encoder_hidden_states,
                                          attention_mask=attention_mask,
                                          **kwargs)

        self.id += 1
        # Wrap around so the same instance can serve another sampling run.
        if self.id == num_recorded:
            self.id = 0

        return res
297
+
298
+
299
class DiffMorpherPipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline that morphs between two images.

    Calling the pipeline inverts both images to noise with DDIM inversion,
    then samples a sequence of frames whose initial noise, text embeddings,
    LoRA weights and (optionally) self-attention context are interpolated
    between the two endpoints.
    """

    def __init__(self,
                 vae: AutoencoderKL,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 unet: UNet2DConditionModel,
                 scheduler: KarrasDiffusionSchedulers,
                 safety_checker: StableDiffusionSafetyChecker,
                 feature_extractor: CLIPImageProcessor,
                 requires_safety_checker: bool = True,
                 ):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         safety_checker, feature_extractor, requires_safety_checker)
        # Self-attention inputs recorded while sampling each endpoint image,
        # filled by StoreProcessor and read by LoadProcessor.
        self.img0_dict = dict()
        self.img1_dict = dict()

    def inv_step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta=0.,
        verbose=False
    ):
        """
        Inverse sampling for DDIM Inversion

        Returns the sample at ``timestep`` computed from the sample ``x`` of
        the preceding (less noisy) timestep, together with the predicted x0.
        ``eta`` is accepted but not used.
        """
        if verbose:
            print("timestep: ", timestep)
        next_step = timestep
        # Step back by one scheduler stride to find the timestep x belongs to.
        timestep = min(timestep - self.scheduler.config.num_train_timesteps //
                       self.scheduler.num_inference_steps, 999)
        alpha_prod_t = self.scheduler.alphas_cumprod[
            timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
        return x_next, pred_x0

    @torch.no_grad()
    def invert(
        self,
        image: torch.Tensor,
        prompt,
        num_inference_steps=50,
        num_actual_inference_steps=None,
        guidance_scale=1.,
        eta=0.0,
        **kwds):
        """
        Invert a real image into a noise map with deterministic DDIM inversion.

        Args:
            image: Image batch passed to ``image2latent``.
            prompt: A string or a list of strings used as conditioning.
            num_inference_steps: Number of scheduler timesteps.
            num_actual_inference_steps: When given, only this many leading
                inversion steps are executed; the rest are skipped.
            guidance_scale: Values above 1 enable classifier-free guidance.
            eta: Accepted but not used.

        Returns:
            The inverted latents.
        """
        DEVICE = torch.device(
            "cuda") if torch.cuda.is_available() else torch.device("cpu")
        batch_size = image.shape[0]
        if isinstance(prompt, list):
            if batch_size == 1:
                image = image.expand(len(prompt), -1, -1, -1)
        elif isinstance(prompt, str):
            if batch_size > 1:
                prompt = [prompt] * batch_size

        # text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            return_tensors="pt"
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
        print("input text embeddings :", text_embeddings.shape)
        # define initial latents
        latents = self.image2latent(image)

        # unconditional embedding for classifier free guidance
        if guidance_scale > 1.:
            unconditional_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=77,
                return_tensors="pt"
            )
            unconditional_embeddings = self.text_encoder(
                unconditional_input.input_ids.to(DEVICE))[0]
            text_embeddings = torch.cat(
                [unconditional_embeddings, text_embeddings], dim=0)

        print("latents shape: ", latents.shape)
        # iterative sampling
        self.scheduler.set_timesteps(num_inference_steps)
        print("Valid timesteps: ", reversed(self.scheduler.timesteps))
        for i, t in enumerate(tqdm.tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
            if num_actual_inference_steps is not None and i >= num_actual_inference_steps:
                continue

            if guidance_scale > 1.:
                model_inputs = torch.cat([latents] * 2)
            else:
                model_inputs = latents

            # predict the noise
            noise_pred = self.unet(
                model_inputs, t, encoder_hidden_states=text_embeddings).sample
            if guidance_scale > 1.:
                noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncon + guidance_scale * \
                    (noise_pred_con - noise_pred_uncon)
            # compute the next noise sample x_t-1 -> x_t
            latents, _ = self.inv_step(noise_pred, t, latents)

        return latents

    @torch.no_grad()
    def ddim_inversion(self, latent, cond):
        """Run DDIM inversion over the scheduler's current timesteps.

        ``self.scheduler.set_timesteps`` must have been called beforehand.

        Args:
            latent: Clean latents to invert.
            cond: Text conditioning; repeated along the batch axis to match
                ``latent.shape[0]``.

        Returns:
            The latents after the last (noisiest) inversion step.
        """
        timesteps = reversed(self.scheduler.timesteps)
        with torch.autocast(device_type='cuda', dtype=torch.float32):
            for i, t in enumerate(tqdm.tqdm(timesteps, desc="DDIM inversion")):
                cond_batch = cond.repeat(latent.shape[0], 1, 1)

                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[timesteps[i - 1]]
                    if i > 0 else self.scheduler.final_alpha_cumprod
                )

                mu = alpha_prod_t ** 0.5
                mu_prev = alpha_prod_t_prev ** 0.5
                sigma = (1 - alpha_prod_t) ** 0.5
                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

                eps = self.unet(
                    latent, t, encoder_hidden_states=cond_batch).sample

                pred_x0 = (latent - sigma_prev * eps) / mu_prev
                latent = mu * pred_x0 + sigma * eps
        return latent

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
    ):
        """
        Predict the sample of the next step in the denoising process.

        Returns ``(x_prev, pred_x0)``.
        """
        prev_timestep = timestep - \
            self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        # Index 0 is a valid entry of alphas_cumprod; only negative indices
        # fall back to final_alpha_cumprod (same convention as inv_step).
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[
            prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
        x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
        return x_prev, pred_x0

    @torch.no_grad()
    def image2latent(self, image):
        """Encode an image into scaled VAE latents.

        Args:
            image: Either a PIL image (converted here to a ``1 x C x H x W``
                float tensor in [-1, 1]) or a tensor that is passed to the
                VAE as-is.

        Returns:
            The mean of the VAE latent distribution multiplied by 0.18215.
        """
        DEVICE = torch.device(
            "cuda") if torch.cuda.is_available() else torch.device("cpu")
        # ``Image`` is the PIL module, so the instance check has to go
        # through ``Image.Image``.
        if isinstance(image, Image.Image):
            image = np.array(image)
            image = torch.from_numpy(image).float() / 127.5 - 1
            image = image.permute(2, 0, 1).unsqueeze(0)
        # input image density range [-1, 1]
        latents = self.vae.encode(image.to(DEVICE))['latent_dist'].mean
        latents = latents * 0.18215
        return latents

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        """Decode latents with the VAE.

        Args:
            latents: Scaled latents (divided by 0.18215 before decoding).
            return_type: ``'np'`` returns the first batch item as an
                ``H x W x C`` uint8 array; ``'pt'`` returns the batch as a
                tensor clamped to [0, 1]; any other value returns the raw
                decoder output.
        """
        latents = 1 / 0.18215 * latents.detach()
        image = self.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        elif return_type == "pt":
            image = (image / 2 + 0.5).clamp(0, 1)

        return image

    def latent2image_grad(self, latents):
        """Decode latents without disabling gradients; returns the raw decoder output."""
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)['sample']

        return image  # range [-1, 1]

    @torch.no_grad()
    def cal_latent(self, num_inference_steps, guidance_scale, unconditioning, img_noise_0, img_noise_1, text_embeddings_0, text_embeddings_1, lora_0, lora_1, alpha, use_lora, fix_lora=None):
        """Sample the latents of one morph frame.

        The initial noise is a spherical interpolation of the two inverted
        noises and the text embedding a linear interpolation, both at
        ``alpha``. When ``use_lora`` is true the UNet LoRA is reloaded at
        ``fix_lora`` if given, otherwise at ``alpha``.

        Args:
            unconditioning: Optional list of per-step unconditional
                embeddings that replace the first half of the text embeddings.

        Returns:
            The denoised latents.
        """
        latents = slerp(img_noise_0, img_noise_1, alpha, self.use_adain)
        text_embeddings = (1 - alpha) * text_embeddings_0 + \
            alpha * text_embeddings_1

        self.scheduler.set_timesteps(num_inference_steps)
        if use_lora:
            if fix_lora is not None:
                self.unet = load_lora(self.unet, lora_0, lora_1, fix_lora)
            else:
                self.unet = load_lora(self.unet, lora_0, lora_1, alpha)

        for i, t in enumerate(tqdm.tqdm(self.scheduler.timesteps, desc=f"DDIM Sampler, alpha={alpha}")):

            if guidance_scale > 1.:
                model_inputs = torch.cat([latents] * 2)
            else:
                model_inputs = latents
            if unconditioning is not None and isinstance(unconditioning, list):
                _, text_embeddings = text_embeddings.chunk(2)
                text_embeddings = torch.cat(
                    [unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
            # predict the noise
            noise_pred = self.unet(
                model_inputs, t, encoder_hidden_states=text_embeddings).sample
            if guidance_scale > 1.0:
                noise_pred_uncon, noise_pred_con = noise_pred.chunk(
                    2, dim=0)
                noise_pred = noise_pred_uncon + guidance_scale * \
                    (noise_pred_con - noise_pred_uncon)
            # compute the previous noise sample x_t -> x_t-1
            # YUJUN: right now, the only difference between step here and step in scheduler
            # is that scheduler version would clamp pred_x0 between [-1,1]
            # don't know if that's gonna have huge impact
            latents = self.scheduler.step(
                noise_pred, t, latents, return_dict=False)[0]
        return latents

    @torch.no_grad()
    def get_text_embeddings(self, prompt, guidance_scale, neg_prompt, batch_size):
        """Encode ``prompt`` with the text encoder.

        When ``guidance_scale > 1`` the embeddings of ``neg_prompt`` (or of
        the empty string), repeated ``batch_size`` times, are concatenated in
        front of the prompt embeddings along the batch axis.
        """
        DEVICE = torch.device(
            "cuda") if torch.cuda.is_available() else torch.device("cpu")
        # text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            return_tensors="pt"
        )
        # Use the detected device rather than a hard-coded .cuda() so the
        # method also works on CPU-only machines.
        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]

        if guidance_scale > 1.:
            uc_text = neg_prompt if neg_prompt else ""
            unconditional_input = self.tokenizer(
                [uc_text] * batch_size,
                padding="max_length",
                max_length=77,
                return_tensors="pt"
            )
            unconditional_embeddings = self.text_encoder(
                unconditional_input.input_ids.to(DEVICE))[0]
            text_embeddings = torch.cat(
                [unconditional_embeddings, text_embeddings], dim=0)

        return text_embeddings

    def _prepare_lora(self, image, prompt, lora_path, index, save_lora_dir,
                      output_path, lora_steps, lora_lr, lora_rank):
        """Return the LoRA state dict for one endpoint, training it first if needed."""
        if not lora_path:
            # Default checkpoint name is derived from the last component of
            # output_path; train only if that file does not exist yet.
            weight_name = f"{output_path.split('/')[-1]}_lora_{index}.ckpt"
            lora_path = save_lora_dir + "/" + weight_name
            if not os.path.exists(lora_path):
                train_lora(image, prompt, save_lora_dir, None, self.tokenizer, self.text_encoder,
                           self.vae, self.unet, self.scheduler, lora_steps, lora_lr, lora_rank, weight_name=weight_name)
        print(f"Load from {lora_path}.")
        if lora_path.endswith(".safetensors"):
            return safetensors.torch.load_file(lora_path, device="cpu")
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        return torch.load(lora_path, map_location="cpu")

    def _install_attn_processors(self, factory):
        """Wrap every processor selected by do_replace_attn with ``factory(processor, key)``."""
        processors = {}
        for key, processor in self.unet.attn_processors.items():
            if do_replace_attn(key):
                processors[key] = factory(processor, key)
            else:
                processors[key] = processor
        self.unet.set_attn_processor(processors)

    def __call__(
        self,
        img_0=None,
        img_1=None,
        img_path_0=None,
        img_path_1=None,
        prompt_0="",
        prompt_1="",
        save_lora_dir="./lora",
        load_lora_path_0=None,
        load_lora_path_1=None,
        lora_steps=200,
        lora_lr=2e-4,
        lora_rank=16,
        batch_size=1,
        height=512,
        width=512,
        num_inference_steps=50,
        num_actual_inference_steps=None,
        guidance_scale=1,
        attn_beta=0,
        lamb=0.6,
        use_lora = True,
        use_adain = True,
        use_reschedule = True,
        output_path = "./results",
        num_frames=50,
        fix_lora=None,
        progress=tqdm,
        unconditioning=None,
        neg_prompt=None,
        **kwds):
        """Generate the morph sequence between two images.

        Args:
            img_0, img_1: Endpoint images; when None they are opened from
                ``img_path_0`` / ``img_path_1`` and converted to RGB.
            prompt_0, prompt_1: Text prompts for the two endpoints.
            save_lora_dir: Directory for trained LoRA checkpoints.
            load_lora_path_0, load_lora_path_1: Existing LoRA checkpoints;
                when empty, a checkpoint is trained (or reused) under
                ``save_lora_dir``.
            lora_steps, lora_lr, lora_rank: Passed to ``train_lora``.
            batch_size: Number of unconditional prompts for guidance.
            num_inference_steps: Number of DDIM steps.
            guidance_scale: Values above 1 enable classifier-free guidance.
            attn_beta: Weight of the current hidden states in the blended
                self-attention context; ``None`` disables attention
                interpolation. Requires ``use_lora=True`` when not None.
            lamb: Fraction of steps with attention replacement.
            use_lora: Interpolate LoRA weights between the endpoints.
            use_adain: Forwarded to ``slerp`` for the initial noise.
            use_reschedule: Sample once with uniform alphas, then resample
                with alphas rescheduled by ``AlphaScheduler``.
            output_path: Frame directory (only written when saving) and
                source of the default LoRA checkpoint name.
            num_frames: Number of frames, endpoints included.
            fix_lora: When given, use this LoRA weight for every frame.
            progress: Object exposing ``tqdm(iterable, desc=...)``.
            unconditioning, neg_prompt: Optional guidance inputs.
            height, width, num_actual_inference_steps, kwds: Accepted but
                not used.

        Returns:
            List of PIL images, one per frame.

        Raises:
            ValueError: If ``use_lora`` is False while ``attn_beta`` is not
                None (attention interpolation relies on the LoRA processors).
        """
        if not use_lora and attn_beta is not None:
            # Previously this combination failed later with a NameError on
            # the never-assigned LoRA weights.
            raise ValueError(
                "attn_beta must be None when use_lora is False: attention "
                "interpolation requires the LoRA attention processors.")

        self.scheduler.set_timesteps(num_inference_steps)
        self.use_lora = use_lora
        self.use_adain = use_adain
        self.use_reschedule = use_reschedule
        self.output_path = output_path

        if img_0 is None:
            img_0 = Image.open(img_path_0).convert("RGB")

        if img_1 is None:
            img_1 = Image.open(img_path_1).convert("RGB")

        # Always bound so the closures below can reference them even when
        # LoRA is disabled.
        lora_0 = None
        lora_1 = None
        if self.use_lora:
            print("Loading lora...")
            lora_0 = self._prepare_lora(
                img_0, prompt_0, load_lora_path_0, 0, save_lora_dir,
                output_path, lora_steps, lora_lr, lora_rank)
            lora_1 = self._prepare_lora(
                img_1, prompt_1, load_lora_path_1, 1, save_lora_dir,
                output_path, lora_steps, lora_lr, lora_rank)

        text_embeddings_0 = self.get_text_embeddings(
            prompt_0, guidance_scale, neg_prompt, batch_size)
        text_embeddings_1 = self.get_text_embeddings(
            prompt_1, guidance_scale, neg_prompt, batch_size)
        img_0 = get_img(img_0)
        img_1 = get_img(img_1)
        # Invert each endpoint with its own LoRA fully applied.
        if self.use_lora:
            self.unet = load_lora(self.unet, lora_0, lora_1, 0)
        img_noise_0 = self.ddim_inversion(
            self.image2latent(img_0), text_embeddings_0)
        if self.use_lora:
            self.unet = load_lora(self.unet, lora_0, lora_1, 1)
        img_noise_1 = self.ddim_inversion(
            self.image2latent(img_1), text_embeddings_1)

        print("latents shape: ", img_noise_0.shape)

        def sample_frame(alpha, reload_lora):
            """Sample one frame at ``alpha`` and return it as a PIL image."""
            latents = self.cal_latent(
                num_inference_steps,
                guidance_scale,
                unconditioning,
                img_noise_0,
                img_noise_1,
                text_embeddings_0,
                text_embeddings_1,
                lora_0,
                lora_1,
                alpha,
                reload_lora,
                fix_lora
            )
            return Image.fromarray(self.latent2image(latents))

        def lora_weight(alpha):
            """LoRA interpolation weight for a frame, honoring ``fix_lora``."""
            return alpha if fix_lora is None else fix_lora

        def morph(alpha_list, progress, desc, save=False):
            images = []
            if attn_beta is not None:
                # First endpoint: record its self-attention inputs.
                self.unet = load_lora(self.unet, lora_0, lora_1, lora_weight(0))
                self._install_attn_processors(
                    lambda proc, key: StoreProcessor(proc, self.img0_dict, key))
                # LoRA is already loaded above, so cal_latent must not reload it.
                first_image = sample_frame(alpha_list[0], False)
                if save:
                    first_image.save(f"{self.output_path}/{0:02d}.png")

                # Last endpoint: record its self-attention inputs.
                self.unet = load_lora(self.unet, lora_0, lora_1, lora_weight(1))
                self._install_attn_processors(
                    lambda proc, key: StoreProcessor(proc, self.img1_dict, key))
                last_image = sample_frame(alpha_list[-1], False)
                if save:
                    last_image.save(
                        f"{self.output_path}/{num_frames - 1:02d}.png")

                # Intermediate frames: blend the recorded attention inputs.
                for i in progress.tqdm(range(1, num_frames - 1), desc=desc):
                    alpha = alpha_list[i]
                    self.unet = load_lora(
                        self.unet, lora_0, lora_1, lora_weight(alpha))
                    self._install_attn_processors(
                        lambda proc, key: LoadProcessor(
                            proc, key, self.img0_dict, self.img1_dict, alpha, attn_beta, lamb))
                    image = sample_frame(alpha, False)
                    if save:
                        image.save(f"{self.output_path}/{i:02d}.png")
                    images.append(image)

                images = [first_image] + images + [last_image]

            else:
                for k, alpha in enumerate(alpha_list):
                    image = sample_frame(alpha, self.use_lora)
                    if save:
                        image.save(f"{self.output_path}/{k:02d}.png")
                    images.append(image)

            return images

        with torch.no_grad():
            if self.use_reschedule:
                # First pass with uniform alphas, used only to measure the
                # distances between consecutive frames.
                alpha_scheduler = AlphaScheduler()
                alpha_list = list(torch.linspace(0, 1, num_frames))
                images_pt = morph(alpha_list, progress, "Sampling...", False)
                images_pt = [transforms.ToTensor()(img).unsqueeze(0)
                             for img in images_pt]
                alpha_scheduler.from_imgs(images_pt)
                alpha_list = alpha_scheduler.get_list()
                print(alpha_list)
                images = morph(alpha_list, progress, "Reschedule...", False)
            else:
                alpha_list = list(torch.linspace(0, 1, num_frames))
                print(alpha_list)
                images = morph(alpha_list, progress, "Sampling...", False)

        return images
814
+
815
+
816
+ # os.makedirs(self.output_path, exist_ok=True)
817
+ # pipeline = DiffMorpherPipeline.from_pretrained(
818
+ # "./stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float32)
819
+ # pipeline.to("cuda")
820
+ # images = pipeline(
821
+ # args.image_path_0,
822
+ # args.image_path_1,
823
+ # args.prompt_0,
824
+ # args.prompt_1
825
+ # )
826
+ # images[0].save(f"{self.output_path}/output.gif", save_all=True,
827
+ # append_images=images[1:], duration=args.duration, loop=0)
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.23.0
2
+ diffusers==0.17.1
3
+ einops==0.7.0
4
+ # gradio==4.7.1
5
+ numpy==1.26.1
6
+ opencv_python==4.5.5.64
7
+ packaging==23.2
8
+ Pillow==10.1.0
9
+ safetensors==0.4.0
10
+ torch
11
+ torchvision
12
+ tqdm==4.65.0
13
+ transformers==4.34.1
14
+ lpips