afeng commited on
Commit
8fa9206
1 Parent(s): 0044ccd
app.py CHANGED
@@ -218,7 +218,7 @@ with gr.Blocks() as demo:
218
  with gr.Tab(label="1 Edit mask"):
219
  with gr.Row():
220
  with gr.Column():
221
- canvas = gr.Image(value = None, type="numpy", tool="sketch", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
222
  input_folder = gr.Textbox(value="example1", label="input folder", interactive= True, )
223
 
224
  segment_button = gr.Button("1.1 Run segmentation")
@@ -283,7 +283,7 @@ with gr.Blocks() as demo:
283
  with gr.Tab(label="2 Optimization"):
284
  with gr.Row():
285
  with gr.Column():
286
- canvas_opt = gr.Image(value = canvas.value, type="pil", tool="sketch", label="Loaded Image", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
287
 
288
  with gr.Column():
289
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
 
218
  with gr.Tab(label="1 Edit mask"):
219
  with gr.Row():
220
  with gr.Column():
221
+ canvas = gr.Image(value = None, type="numpy", label="Draw Mask", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
222
  input_folder = gr.Textbox(value="example1", label="input folder", interactive= True, )
223
 
224
  segment_button = gr.Button("1.1 Run segmentation")
 
283
  with gr.Tab(label="2 Optimization"):
284
  with gr.Row():
285
  with gr.Column():
286
+ canvas_opt = gr.Image(value = canvas.value, type="pil", label="Loaded Image", show_label=True, height=LENGTH, width=LENGTH, interactive=True)
287
 
288
  with gr.Column():
289
  gr.Markdown("""<p style="text-align: center; font-size: 20px">Optimization settings (SD)</p>""")
assets/demo1.gif DELETED
Binary file (724 kB)
 
assets/demo2.gif DELETED
Binary file (941 kB)
 
assets/demo3.gif DELETED
Binary file (761 kB)
 
assets/demo4.gif DELETED
Binary file (530 kB)
 
assets/mask_def.png DELETED
Binary file (41.5 kB)
 
example2/img.png DELETED
Binary file (956 kB)
 
main.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  import numpy as np
4
  import argparse
5
  from peft import LoraConfig
6
- from pipeline_dedit_sdxl import DEditSDXLPipeline
7
  from pipeline_dedit_sd import DEditSDPipeline
8
  from utils import load_image, load_mask, load_mask_edit
9
  from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
 
3
  import numpy as np
4
  import argparse
5
  from peft import LoraConfig
6
+ from old.pipeline_dedit_sdxl import DEditSDXLPipeline
7
  from pipeline_dedit_sd import DEditSDPipeline
8
  from utils import load_image, load_mask, load_mask_edit
9
  from utils_mask import process_mask_move_torch, process_mask_remove_torch, mask_union_torch, mask_substract_torch, create_outer_edge_mask_torch
pipeline_dedit_sdxl.py DELETED
@@ -1,875 +0,0 @@
1
- import torch
2
- from utils import import_model_class_from_model_name_or_path
3
- from transformers import AutoTokenizer
4
- from diffusers import (
5
- AutoencoderKL,
6
- DDPMScheduler,
7
- StableDiffusionXLPipeline,
8
- UNet2DConditionModel,
9
- )
10
- from accelerate import Accelerator
11
- from tqdm.auto import tqdm
12
- from utils import sdxl_prepare_input_decom, save_images
13
- import torch.nn.functional as F
14
- import itertools
15
- from peft import LoraConfig
16
- from controller import GroupedCAController, register_attention_disentangled_control, DummyController
17
- from utils import image2latent, latent2image
18
- import matplotlib.pyplot as plt
19
- from utils_mask import check_mask_overlap_torch
20
-
21
- device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
22
- max_length = 40
23
- class DEditSDXLPipeline:
24
- def __init__(
25
- self,
26
- mask_list,
27
- mask_label_list,
28
- mask_list_2 = None,
29
- mask_label_list_2 = None,
30
- resolution = 1024,
31
- num_tokens = 1
32
- ):
33
- super().__init__()
34
- model_id = "stabilityai/stable-diffusion-xl-base-1.0"
35
- self.model_id = model_id
36
- self.tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer", use_fast=False)
37
- self.tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", use_fast=False)
38
- text_encoder_cls_one = import_model_class_from_model_name_or_path(model_id, subfolder = "text_encoder")
39
- text_encoder_cls_two = import_model_class_from_model_name_or_path(model_id, subfolder="text_encoder_2")
40
- self.text_encoder = text_encoder_cls_one.from_pretrained(model_id, subfolder="text_encoder" ).to(device)
41
- self.text_encoder_2 = text_encoder_cls_two.from_pretrained(model_id, subfolder="text_encoder_2").to(device)
42
- self.unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet" )
43
- self.unet.ca_dim = 2048
44
- self.vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix")
45
- self.scheduler = DDPMScheduler.from_pretrained(model_id , subfolder="scheduler")
46
-
47
- self.mixed_precision = "fp16"
48
- self.resolution = resolution
49
- self.num_tokens = num_tokens
50
-
51
- self.mask_list = mask_list
52
- self.mask_label_list = mask_label_list
53
- notation_token_list = [phrase.split(" ")[-1] for phrase in mask_label_list]
54
- placeholder_token_list = ["#"+word+"{}".format(widx) for widx, word in enumerate(notation_token_list)]
55
- self.set_string_list, placeholder_token_ids = self.add_tokens(placeholder_token_list)
56
- self.min_added_id = min(placeholder_token_ids)
57
- self.max_added_id = max(placeholder_token_ids)
58
-
59
- if mask_list_2 is not None:
60
- self.mask_list_2 = mask_list_2
61
- self.mask_label_list_2 = mask_label_list_2
62
- notation_token_list_2 = [phrase.split(" ")[-1] for phrase in mask_label_list_2]
63
-
64
- placeholder_token_list_2 = ["$"+word+"{}".format(widx) for widx, word in enumerate(notation_token_list_2)]
65
- self.set_string_list_2, placeholder_token_ids_2 = self.add_tokens(placeholder_token_list_2)
66
- self.max_added_id = max(placeholder_token_ids_2)
67
-
68
- def add_tokens_text_encoder_random_init(self, placeholder_token, num_tokens=1):
69
- # Add the placeholder token in tokenizer
70
- placeholder_tokens = [placeholder_token]
71
- # add dummy tokens for multi-vector
72
- additional_tokens = []
73
- for i in range(1, num_tokens):
74
- additional_tokens.append(f"{placeholder_token}_{i}")
75
- placeholder_tokens += additional_tokens
76
- num_added_tokens = self.tokenizer.add_tokens(placeholder_tokens) # 49408
77
- num_added_tokens = self.tokenizer_2.add_tokens(placeholder_tokens) # 49408
78
-
79
- if num_added_tokens != num_tokens:
80
- raise ValueError(
81
- f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
82
- " `placeholder_token` that is not already in the tokenizer."
83
- )
84
- placeholder_token_ids = self.tokenizer.convert_tokens_to_ids(placeholder_tokens)
85
- placeholder_token_ids_2 = self.tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
86
- assert placeholder_token_ids == placeholder_token_ids_2, "Two text encoders are expected to have same vocabs"
87
-
88
- self.text_encoder.resize_token_embeddings(len(self.tokenizer))
89
- token_embeds = self.text_encoder.get_input_embeddings().weight.data
90
- std, mean = torch.std_mean(token_embeds)
91
- with torch.no_grad():
92
- for token_id in placeholder_token_ids:
93
- token_embeds[token_id] = torch.randn_like(token_embeds[token_id])*std + mean
94
-
95
- self.text_encoder_2.resize_token_embeddings(len(self.tokenizer))
96
- token_embeds = self.text_encoder_2.get_input_embeddings().weight.data
97
- std, mean = torch.std_mean(token_embeds)
98
- with torch.no_grad():
99
- for token_id in placeholder_token_ids:
100
- token_embeds[token_id] = torch.randn_like(token_embeds[token_id])*std + mean
101
-
102
- set_string = " ".join(self.tokenizer.convert_ids_to_tokens(placeholder_token_ids))
103
-
104
- return set_string, placeholder_token_ids
105
-
106
- def add_tokens(self, placeholder_token_list):
107
- set_string_list = []
108
- placeholder_token_ids_list = []
109
- for str_idx in range(len(placeholder_token_list)):
110
- placeholder_token = placeholder_token_list[str_idx]
111
- set_string, placeholder_token_ids = self.add_tokens_text_encoder_random_init(placeholder_token, num_tokens=self.num_tokens)
112
- set_string_list.append(set_string)
113
- placeholder_token_ids_list.append(placeholder_token_ids)
114
- placeholder_token_ids = list(itertools.chain(*placeholder_token_ids_list))
115
- return set_string_list, placeholder_token_ids
116
-
117
- def train_emb(
118
- self,
119
- image_gt,
120
- set_string_list,
121
- gradient_accumulation_steps = 5,
122
- embedding_learning_rate = 1e-4,
123
- max_emb_train_steps = 100,
124
- train_batch_size = 1,
125
- train_full_lora = False
126
- ):
127
- decom_controller = GroupedCAController(mask_list = self.mask_list)
128
- register_attention_disentangled_control(self.unet, decom_controller)
129
-
130
- accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
131
- self.vae.requires_grad_(False)
132
- self.unet.requires_grad_(False)
133
-
134
- self.text_encoder.requires_grad_(True)
135
- self.text_encoder_2.requires_grad_(True)
136
-
137
- self.text_encoder.text_model.encoder.requires_grad_(False)
138
- self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
139
- self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
140
-
141
- self.text_encoder_2.text_model.encoder.requires_grad_(False)
142
- self.text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
143
- self.text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
144
-
145
- weight_dtype = torch.float32
146
- if accelerator.mixed_precision == "fp16":
147
- weight_dtype = torch.float16
148
- elif accelerator.mixed_precision == "bf16":
149
- weight_dtype = torch.bfloat16
150
-
151
- self.unet.to(device, dtype=weight_dtype)
152
- self.vae.to(device, dtype=weight_dtype)
153
-
154
- trainable_embmat_list_1 = [param for param in self.text_encoder.get_input_embeddings().parameters()]
155
- trainable_embmat_list_2 = [param for param in self.text_encoder_2.get_input_embeddings().parameters()]
156
-
157
- optimizer = torch.optim.AdamW(trainable_embmat_list_1 + trainable_embmat_list_2, lr=embedding_learning_rate)
158
-
159
- self.text_encoder, self.text_encoder_2, optimizer = accelerator.prepare(self.text_encoder, self.text_encoder_2, optimizer)
160
-
161
- orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder) .get_input_embeddings().weight.data.clone()
162
- orig_embeds_params_2 = accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight.data.clone()
163
-
164
- self.text_encoder.train()
165
- self.text_encoder_2.train()
166
-
167
- effective_emb_train_steps = max_emb_train_steps//gradient_accumulation_steps
168
-
169
- if accelerator.is_main_process:
170
- accelerator.init_trackers("DEdit EmbSteps", config={
171
- "embedding_learning_rate": embedding_learning_rate,
172
- "text_embedding_optimization_steps": effective_emb_train_steps,
173
- })
174
- global_step = 0
175
- noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0" , subfolder="scheduler")
176
- progress_bar = tqdm(range(0, effective_emb_train_steps), initial = global_step, desc="EmbSteps")
177
- latents0 = image2latent(image_gt, vae = self.vae, dtype=weight_dtype)
178
- latents0 = latents0.repeat(train_batch_size, 1, 1, 1)
179
-
180
- for _ in range(max_emb_train_steps):
181
- with accelerator.accumulate(self.text_encoder, self.text_encoder_2):
182
- latents = latents0.clone().detach()
183
- noise = torch.randn_like(latents)
184
- bsz = latents.shape[0]
185
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
186
- timesteps = timesteps.long()
187
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
188
- encoder_hidden_states_list, add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
189
- set_string_list,
190
- self.tokenizer,
191
- self.tokenizer_2,
192
- self.text_encoder,
193
- self.text_encoder_2,
194
- length = max_length,
195
- bsz = train_batch_size,
196
- weight_dtype = weight_dtype
197
- )
198
-
199
- model_pred = self.unet(
200
- noisy_latents,
201
- timesteps,
202
- encoder_hidden_states = encoder_hidden_states_list,
203
- cross_attention_kwargs = None,
204
- added_cond_kwargs={"text_embeds": add_text_embeds, "time_ids": add_time_ids},
205
- return_dict=False
206
- )[0]
207
- loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
208
- accelerator.backward(loss)
209
- optimizer.step()
210
- optimizer.zero_grad()
211
-
212
- index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
213
- index_no_updates[self.min_added_id : self.max_added_id + 1] = False
214
- with torch.no_grad():
215
- accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
216
- index_no_updates] = orig_embeds_params_1[index_no_updates]
217
-
218
- index_no_updates = torch.ones((len(self.tokenizer_2),), dtype=torch.bool)
219
- index_no_updates[self.min_added_id : self.max_added_id + 1] = False
220
- with torch.no_grad():
221
- accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight[
222
- index_no_updates] = orig_embeds_params_2[index_no_updates]
223
-
224
- logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
225
- progress_bar.set_postfix(**logs)
226
- accelerator.log(logs, step=global_step)
227
- if accelerator.sync_gradients:
228
- progress_bar.update(1)
229
- global_step += 1
230
-
231
- if global_step >= max_emb_train_steps:
232
- break
233
- accelerator.wait_for_everyone()
234
- accelerator.end_training()
235
- self.text_encoder = accelerator.unwrap_model(self.text_encoder).to(dtype = weight_dtype)
236
- self.text_encoder_2 = accelerator.unwrap_model(self.text_encoder_2).to(dtype = weight_dtype)
237
-
238
- def train_model(
239
- self,
240
- image_gt,
241
- set_string_list,
242
- gradient_accumulation_steps = 5,
243
- max_diffusion_train_steps = 100,
244
- diffusion_model_learning_rate = 1e-5,
245
- train_batch_size = 1,
246
- train_full_lora = False,
247
- lora_rank = 4,
248
- lora_alpha = 4
249
- ):
250
- self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
251
- self.unet.ca_dim = 2048
252
- decom_controller = GroupedCAController(mask_list = self.mask_list)
253
- register_attention_disentangled_control(self.unet, decom_controller)
254
-
255
- mixed_precision = "fp16"
256
- accelerator = Accelerator(gradient_accumulation_steps = gradient_accumulation_steps, mixed_precision = mixed_precision)
257
-
258
- weight_dtype = torch.float32
259
- if accelerator.mixed_precision == "fp16":
260
- weight_dtype = torch.float16
261
- elif accelerator.mixed_precision == "bf16":
262
- weight_dtype = torch.bfloat16
263
-
264
- self.vae.requires_grad_(False)
265
- self.vae.to(device, dtype=weight_dtype)
266
-
267
- self.unet.requires_grad_(False)
268
- self.unet.train()
269
-
270
- self.text_encoder.requires_grad_(False)
271
- self.text_encoder_2.requires_grad_(False)
272
-
273
- if not train_full_lora:
274
- trainable_params_list = []
275
- for _, module in self.unet.named_modules():
276
- module_name = type(module).__name__
277
- if module_name == "Attention":
278
- if module.to_k.in_features == 2048: # this is cross attention:
279
- module.to_k.weight.requires_grad = True
280
- trainable_params_list.append(module.to_k.weight)
281
- if module.to_k.bias is not None:
282
- module.to_k.bias.requires_grad = True
283
- trainable_params_list.append(module.to_k.bias)
284
- module.to_v.weight.requires_grad = True
285
- trainable_params_list.append(module.to_v.weight)
286
- if module.to_v.bias is not None:
287
- module.to_v.bias.requires_grad = True
288
- trainable_params_list.append(module.to_v.bias)
289
- module.to_q.weight.requires_grad = True
290
- trainable_params_list.append(module.to_q.weight)
291
- if module.to_q.bias is not None:
292
- module.to_q.bias.requires_grad = True
293
- trainable_params_list.append(module.to_q.bias)
294
- else:
295
- unet_lora_config = LoraConfig(
296
- r=lora_rank,
297
- lora_alpha=lora_alpha,
298
- init_lora_weights="gaussian",
299
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
300
- )
301
- self.unet.add_adapter(unet_lora_config)
302
- print("training full parameters using lora!")
303
- trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))
304
-
305
- self.text_encoder.to(device, dtype=weight_dtype)
306
- self.text_encoder_2.to(device, dtype=weight_dtype)
307
- optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
308
- self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
309
- psum2 = sum(p.numel() for p in trainable_params_list)
310
-
311
- effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
312
- if accelerator.is_main_process:
313
- accelerator.init_trackers("textual_inversion", config={
314
- "diffusion_model_learning_rate": diffusion_model_learning_rate,
315
- "diffusion_model_optimization_steps": effective_diffusion_train_steps,
316
- })
317
-
318
- global_step = 0
319
- progress_bar = tqdm( range(0, effective_diffusion_train_steps),initial=global_step, desc="ModelSteps")
320
-
321
- noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0" , subfolder="scheduler")
322
-
323
- latents0 = image2latent(image_gt, vae = self.vae, dtype=weight_dtype)
324
- latents0 = latents0.repeat(train_batch_size, 1, 1, 1)
325
-
326
- with torch.no_grad():
327
- encoder_hidden_states_list, add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
328
- set_string_list,
329
- self.tokenizer,
330
- self.tokenizer_2,
331
- self.text_encoder,
332
- self.text_encoder_2,
333
- length = max_length,
334
- bsz = train_batch_size,
335
- weight_dtype = weight_dtype
336
- )
337
-
338
- for _ in range(max_diffusion_train_steps):
339
- with accelerator.accumulate(self.unet):
340
- latents = latents0.clone().detach()
341
- noise = torch.randn_like(latents)
342
- bsz = latents.shape[0]
343
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
344
- timesteps = timesteps.long()
345
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
346
- model_pred = self.unet(
347
- noisy_latents,
348
- timesteps,
349
- encoder_hidden_states=encoder_hidden_states_list,
350
- cross_attention_kwargs=None, return_dict=False,
351
- added_cond_kwargs={"text_embeds": add_text_embeds, "time_ids": add_time_ids}
352
- )[0]
353
- loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
354
- accelerator.backward(loss)
355
- optimizer.step()
356
- optimizer.zero_grad()
357
-
358
- logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
359
- progress_bar.set_postfix(**logs)
360
- accelerator.log(logs, step=global_step)
361
- if accelerator.sync_gradients:
362
- progress_bar.update(1)
363
- global_step += 1
364
- if global_step >=max_diffusion_train_steps:
365
- break
366
- accelerator.wait_for_everyone()
367
- accelerator.end_training()
368
- self.unet = accelerator.unwrap_model(self.unet).to(dtype = weight_dtype)
369
-
370
- def train_emb_2imgs(
371
- self,
372
- image_gt_1,
373
- image_gt_2,
374
- set_string_list_1,
375
- set_string_list_2,
376
- gradient_accumulation_steps = 5,
377
- embedding_learning_rate = 1e-4,
378
- max_emb_train_steps = 100,
379
- train_batch_size = 1,
380
- train_full_lora = False
381
- ):
382
- decom_controller_1 = GroupedCAController(mask_list = self.mask_list)
383
- decom_controller_2 = GroupedCAController(mask_list = self.mask_list_2)
384
- accelerator = Accelerator(mixed_precision=self.mixed_precision, gradient_accumulation_steps=gradient_accumulation_steps)
385
- self.vae.requires_grad_(False)
386
- self.unet.requires_grad_(False)
387
-
388
- self.text_encoder.requires_grad_(True)
389
- self.text_encoder_2.requires_grad_(True)
390
-
391
- self.text_encoder.text_model.encoder.requires_grad_(False)
392
- self.text_encoder.text_model.final_layer_norm.requires_grad_(False)
393
- self.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
394
-
395
- self.text_encoder_2.text_model.encoder.requires_grad_(False)
396
- self.text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
397
- self.text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
398
-
399
- weight_dtype = torch.float32
400
- if accelerator.mixed_precision == "fp16":
401
- weight_dtype = torch.float16
402
- elif accelerator.mixed_precision == "bf16":
403
- weight_dtype = torch.bfloat16
404
-
405
- self.unet.to(device, dtype=weight_dtype)
406
- self.vae.to(device, dtype=weight_dtype)
407
-
408
-
409
- trainable_embmat_list_1 = [param for param in self.text_encoder.get_input_embeddings().parameters()]
410
- trainable_embmat_list_2 = [param for param in self.text_encoder_2.get_input_embeddings().parameters()]
411
-
412
- optimizer = torch.optim.AdamW(trainable_embmat_list_1 + trainable_embmat_list_2, lr=embedding_learning_rate)
413
- self.text_encoder, self.text_encoder_2, optimizer= accelerator.prepare(self.text_encoder, self.text_encoder_2, optimizer) ###
414
- orig_embeds_params_1 = accelerator.unwrap_model(self.text_encoder) .get_input_embeddings().weight.data.clone()
415
- orig_embeds_params_2 = accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight.data.clone()
416
-
417
- self.text_encoder.train()
418
- self.text_encoder_2.train()
419
-
420
- effective_emb_train_steps = max_emb_train_steps//gradient_accumulation_steps
421
-
422
- if accelerator.is_main_process:
423
- accelerator.init_trackers("EmbFt", config={
424
- "embedding_learning_rate": embedding_learning_rate,
425
- "text_embedding_optimization_steps": effective_emb_train_steps,
426
- })
427
-
428
- global_step = 0
429
-
430
- noise_scheduler = DDPMScheduler.from_pretrained(self.model_id , subfolder="scheduler")
431
- progress_bar = tqdm(range(0, effective_emb_train_steps),initial=global_step,desc="EmbSteps")
432
- latents0_1 = image2latent(image_gt_1, vae = self.vae, dtype=weight_dtype)
433
- latents0_1 = latents0_1.repeat(train_batch_size,1,1,1)
434
-
435
- latents0_2 = image2latent(image_gt_2, vae = self.vae, dtype=weight_dtype)
436
- latents0_2 = latents0_2.repeat(train_batch_size,1,1,1)
437
-
438
- for step in range(max_emb_train_steps):
439
- with accelerator.accumulate(self.text_encoder, self.text_encoder_2):
440
- latents_1 = latents0_1.clone().detach()
441
- noise_1 = torch.randn_like(latents_1)
442
-
443
- latents_2 = latents0_2.clone().detach()
444
- noise_2 = torch.randn_like(latents_2)
445
-
446
- bsz = latents_1.shape[0]
447
-
448
- timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
449
- timesteps_1 = timesteps_1.long()
450
- noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)
451
-
452
- timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
453
- timesteps_2 = timesteps_2.long()
454
- noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)
455
-
456
- register_attention_disentangled_control(self.unet, decom_controller_1)
457
- encoder_hidden_states_list_1, add_text_embeds_1, add_time_ids_1 = sdxl_prepare_input_decom(
458
- set_string_list_1,
459
- self.tokenizer,
460
- self.tokenizer_2,
461
- self.text_encoder,
462
- self.text_encoder_2,
463
- length = max_length,
464
- bsz = train_batch_size,
465
- weight_dtype = weight_dtype
466
- )
467
-
468
- model_pred_1 = self.unet(
469
- noisy_latents_1,
470
- timesteps_1,
471
- encoder_hidden_states=encoder_hidden_states_list_1,
472
- cross_attention_kwargs=None,
473
- added_cond_kwargs={"text_embeds": add_text_embeds_1, "time_ids": add_time_ids_1},
474
- return_dict=False
475
- )[0]
476
-
477
- register_attention_disentangled_control(self.unet, decom_controller_2)
478
- # import pdb; pdb.set_trace()
479
- encoder_hidden_states_list_2, add_text_embeds_2, add_time_ids_2 = sdxl_prepare_input_decom(
480
- set_string_list_2,
481
- self.tokenizer,
482
- self.tokenizer_2,
483
- self.text_encoder,
484
- self.text_encoder_2,
485
- length = max_length,
486
- bsz = train_batch_size,
487
- weight_dtype = weight_dtype
488
- )
489
-
490
- model_pred_2 = self.unet(
491
- noisy_latents_2,
492
- timesteps_2,
493
- encoder_hidden_states = encoder_hidden_states_list_2,
494
- cross_attention_kwargs=None,
495
- added_cond_kwargs={"text_embeds": add_text_embeds_2, "time_ids": add_time_ids_2},
496
- return_dict=False
497
- )[0]
498
-
499
- loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean") /2
500
- loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean") /2
501
- loss = loss_1 + loss_2
502
- accelerator.backward(loss)
503
- optimizer.step()
504
- optimizer.zero_grad()
505
-
506
- index_no_updates = torch.ones((len(self.tokenizer),), dtype=torch.bool)
507
- index_no_updates[self.min_added_id : self.max_added_id + 1] = False
508
- with torch.no_grad():
509
- accelerator.unwrap_model(self.text_encoder).get_input_embeddings().weight[
510
- index_no_updates] = orig_embeds_params_1[index_no_updates]
511
- index_no_updates = torch.ones((len(self.tokenizer_2),), dtype=torch.bool)
512
- index_no_updates[self.min_added_id : self.max_added_id + 1] = False
513
- with torch.no_grad():
514
- accelerator.unwrap_model(self.text_encoder_2).get_input_embeddings().weight[
515
- index_no_updates] = orig_embeds_params_2[index_no_updates]
516
-
517
- logs = {"loss": loss.detach().item(), "lr": embedding_learning_rate}
518
- progress_bar.set_postfix(**logs)
519
- accelerator.log(logs, step=global_step)
520
- if accelerator.sync_gradients:
521
- progress_bar.update(1)
522
- global_step += 1
523
-
524
- if global_step >= max_emb_train_steps:
525
- break
526
- accelerator.wait_for_everyone()
527
- accelerator.end_training()
528
- self.text_encoder = accelerator.unwrap_model(self.text_encoder) .to(dtype = weight_dtype)
529
- self.text_encoder_2 = accelerator.unwrap_model(self.text_encoder_2).to(dtype = weight_dtype)
530
-
531
- def train_model_2imgs(
532
- self,
533
- image_gt_1,
534
- image_gt_2,
535
- set_string_list_1,
536
- set_string_list_2,
537
- gradient_accumulation_steps = 5,
538
- max_diffusion_train_steps = 100,
539
- diffusion_model_learning_rate = 1e-5,
540
- train_batch_size = 1,
541
- train_full_lora = False,
542
- lora_rank = 4,
543
- lora_alpha = 4
544
- ):
545
- self.unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet").to(device)
546
- self.unet.ca_dim = 2048
547
- decom_controller_1 = GroupedCAController(mask_list = self.mask_list)
548
- decom_controller_2 = GroupedCAController(mask_list = self.mask_list_2)
549
-
550
- mixed_precision = "fp16"
551
- accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps,mixed_precision=mixed_precision)
552
-
553
- weight_dtype = torch.float32
554
- if accelerator.mixed_precision == "fp16":
555
- weight_dtype = torch.float16
556
- elif accelerator.mixed_precision == "bf16":
557
- weight_dtype = torch.bfloat16
558
-
559
-
560
- self.vae.requires_grad_(False)
561
- self.vae.to(device, dtype=weight_dtype)
562
- self.unet.requires_grad_(False)
563
- self.unet.train()
564
-
565
- self.text_encoder.requires_grad_(False)
566
- self.text_encoder_2.requires_grad_(False)
567
- if not train_full_lora:
568
- trainable_params_list = []
569
- for name, module in self.unet.named_modules():
570
- module_name = type(module).__name__
571
- if module_name == "Attention":
572
- if module.to_k.in_features == 2048: # this is cross attention:
573
- module.to_k.weight.requires_grad = True
574
- trainable_params_list.append(module.to_k.weight)
575
- if module.to_k.bias is not None:
576
- module.to_k.bias.requires_grad = True
577
- trainable_params_list.append(module.to_k.bias)
578
-
579
- module.to_v.weight.requires_grad = True
580
- trainable_params_list.append(module.to_v.weight)
581
- if module.to_v.bias is not None:
582
- module.to_v.bias.requires_grad = True
583
- trainable_params_list.append(module.to_v.bias)
584
- module.to_q.weight.requires_grad = True
585
- trainable_params_list.append(module.to_q.weight)
586
- if module.to_q.bias is not None:
587
- module.to_q.bias.requires_grad = True
588
- trainable_params_list.append(module.to_q.bias)
589
- else:
590
- unet_lora_config = LoraConfig(
591
- r = lora_rank,
592
- lora_alpha = lora_alpha,
593
- init_lora_weights="gaussian",
594
- target_modules=["to_k", "to_q", "to_v", "to_out.0"],
595
- )
596
- self.unet.add_adapter(unet_lora_config)
597
- print("training full parameters using lora!")
598
- trainable_params_list = list(filter(lambda p: p.requires_grad, self.unet.parameters()))
599
-
600
- self.text_encoder.to(device, dtype=weight_dtype)
601
- self.text_encoder_2.to(device, dtype=weight_dtype)
602
- optimizer = torch.optim.AdamW(trainable_params_list, lr=diffusion_model_learning_rate)
603
- self.unet, optimizer = accelerator.prepare(self.unet, optimizer)
604
- psum2 = sum(p.numel() for p in trainable_params_list)
605
-
606
- effective_diffusion_train_steps = max_diffusion_train_steps // gradient_accumulation_steps
607
- if accelerator.is_main_process:
608
- accelerator.init_trackers("ModelFt", config={
609
- "diffusion_model_learning_rate": diffusion_model_learning_rate,
610
- "diffusion_model_optimization_steps": effective_diffusion_train_steps,
611
- })
612
-
613
- global_step = 0
614
- progress_bar = tqdm(range(0, effective_diffusion_train_steps),initial=global_step, desc="ModelSteps")
615
- noise_scheduler = DDPMScheduler.from_pretrained(self.model_id, subfolder="scheduler")
616
-
617
- latents0_1 = image2latent(image_gt_1, vae = self.vae, dtype=weight_dtype)
618
- latents0_1 = latents0_1.repeat(train_batch_size, 1, 1, 1)
619
-
620
- latents0_2 = image2latent(image_gt_2, vae = self.vae, dtype=weight_dtype)
621
- latents0_2 = latents0_2.repeat(train_batch_size,1, 1, 1)
622
-
623
- with torch.no_grad():
624
- encoder_hidden_states_list_1, add_text_embeds_1, add_time_ids_1 = sdxl_prepare_input_decom(
625
- set_string_list_1,
626
- self.tokenizer,
627
- self.tokenizer_2,
628
- self.text_encoder,
629
- self.text_encoder_2,
630
- length = max_length,
631
- bsz = train_batch_size,
632
- weight_dtype = weight_dtype
633
- )
634
- encoder_hidden_states_list_2, add_text_embeds_2, add_time_ids_2 = sdxl_prepare_input_decom(
635
- set_string_list_2,
636
- self.tokenizer,
637
- self.tokenizer_2,
638
- self.text_encoder,
639
- self.text_encoder_2,
640
- length = max_length,
641
- bsz = train_batch_size,
642
- weight_dtype = weight_dtype
643
- )
644
-
645
- for _ in range(max_diffusion_train_steps):
646
- with accelerator.accumulate(self.unet):
647
- latents_1 = latents0_1.clone().detach()
648
- noise_1 = torch.randn_like(latents_1)
649
- bsz = latents_1.shape[0]
650
- timesteps_1 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_1.device)
651
- timesteps_1 = timesteps_1.long()
652
- noisy_latents_1 = noise_scheduler.add_noise(latents_1, noise_1, timesteps_1)
653
-
654
- latents_2 = latents0_2.clone().detach()
655
- noise_2 = torch.randn_like(latents_2)
656
- bsz = latents_2.shape[0]
657
- timesteps_2 = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents_2.device)
658
- timesteps_2 = timesteps_2.long()
659
- noisy_latents_2 = noise_scheduler.add_noise(latents_2, noise_2, timesteps_2)
660
-
661
- register_attention_disentangled_control(self.unet, decom_controller_1)
662
- model_pred_1 = self.unet(
663
- noisy_latents_1,
664
- timesteps_1,
665
- encoder_hidden_states = encoder_hidden_states_list_1,
666
- cross_attention_kwargs = None,
667
- return_dict = False,
668
- added_cond_kwargs = {"text_embeds": add_text_embeds_1, "time_ids": add_time_ids_1}
669
- )[0]
670
-
671
- register_attention_disentangled_control(self.unet, decom_controller_2)
672
- model_pred_2 = self.unet(
673
- noisy_latents_2,
674
- timesteps_2,
675
- encoder_hidden_states = encoder_hidden_states_list_2,
676
- cross_attention_kwargs = None,
677
- return_dict=False,
678
- added_cond_kwargs={"text_embeds": add_text_embeds_2, "time_ids": add_time_ids_2}
679
- )[0]
680
-
681
- loss_1 = F.mse_loss(model_pred_1.float(), noise_1.float(), reduction="mean")
682
- loss_2 = F.mse_loss(model_pred_2.float(), noise_2.float(), reduction="mean")
683
- loss = loss_1 + loss_2
684
- accelerator.backward(loss)
685
- optimizer.step()
686
- optimizer.zero_grad()
687
-
688
-
689
- logs = {"loss": loss.detach().item(), "lr": diffusion_model_learning_rate}
690
- progress_bar.set_postfix(**logs)
691
- accelerator.log(logs, step=global_step)
692
- if accelerator.sync_gradients:
693
- progress_bar.update(1)
694
- global_step += 1
695
-
696
- if global_step >=max_diffusion_train_steps:
697
- break
698
- accelerator.wait_for_everyone()
699
- accelerator.end_training()
700
- self.unet = accelerator.unwrap_model(self.unet).to(dtype = weight_dtype)
701
-
702
- @torch.no_grad()
703
- def backward_zT_to_z0_euler_decom(
704
- self,
705
- zT,
706
- cond_emb_list,
707
- cond_add_text_embeds,
708
- add_time_ids,
709
- uncond_emb=None,
710
- guidance_scale = 1,
711
- num_sampling_steps = 20,
712
- cond_controller = None,
713
- uncond_controller = None,
714
- mask_hard = None,
715
- mask_soft = None,
716
- orig_image = None,
717
- return_intermediate = False,
718
- strength = 1
719
- ):
720
- latent_cur = zT
721
- if uncond_emb is None:
722
- uncond_emb = torch.zeros(zT.shape[0], 77, 2048).to(dtype = zT.dtype, device = zT.device)
723
- uncond_add_text_embeds = torch.zeros(1, 1280).to(dtype = zT.dtype, device = zT.device)
724
- if mask_soft is not None:
725
- init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
726
- length = init_latents_orig.shape[-1]
727
- noise = torch.randn_like(init_latents_orig)
728
- mask_soft = torch.nn.functional.interpolate(mask_soft.float().unsqueeze(0).unsqueeze(0), (length, length)).to(self.vae.dtype) ###
729
- if mask_hard is not None:
730
- init_latents_orig = image2latent(orig_image, self.vae, dtype=self.vae.dtype)
731
- length = init_latents_orig.shape[-1]
732
- noise = torch.randn_like(init_latents_orig)
733
- mask_hard = torch.nn.functional.interpolate(mask_hard.float().unsqueeze(0).unsqueeze(0), (length, length)).to(self.vae.dtype) ###
734
-
735
- intermediate_list = [latent_cur.detach()]
736
- for i in tqdm(range(num_sampling_steps)):
737
- t = self.scheduler.timesteps[i]
738
- latent_input = self.scheduler.scale_model_input(latent_cur, t)
739
-
740
- register_attention_disentangled_control(self.unet, uncond_controller)
741
- noise_pred_uncond = self.unet(latent_input, t,
742
- encoder_hidden_states=uncond_emb,
743
- added_cond_kwargs={"text_embeds": uncond_add_text_embeds, "time_ids": add_time_ids},
744
- return_dict=False,)[0]
745
-
746
- register_attention_disentangled_control(self.unet, cond_controller)
747
- noise_pred_cond = self.unet(latent_input, t,
748
- encoder_hidden_states=cond_emb_list,
749
- added_cond_kwargs={"text_embeds": cond_add_text_embeds, "time_ids": add_time_ids},
750
- return_dict=False,)[0]
751
-
752
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
753
- latent_cur = self.scheduler.step(noise_pred, t, latent_cur, generator = None, return_dict=False)[0]
754
- if return_intermediate is True:
755
- intermediate_list.append(latent_cur)
756
- if mask_hard is not None and mask_soft is not None and i <= strength *num_sampling_steps:
757
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
758
- mask = mask_soft.to(latent_cur.device, latent_cur.dtype) + mask_hard.to(latent_cur.device, latent_cur.dtype)
759
- latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
760
-
761
- elif mask_hard is not None and mask_soft is not None and i > strength *num_sampling_steps:
762
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
763
- mask = mask_hard.to(latent_cur.device, latent_cur.dtype)
764
- latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
765
-
766
- elif mask_hard is None and mask_soft is not None and i <= strength *num_sampling_steps:
767
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
768
- mask = mask_soft.to(latent_cur.device, latent_cur.dtype)
769
- latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
770
-
771
- elif mask_hard is None and mask_soft is not None and i > strength *num_sampling_steps:
772
- pass
773
-
774
- elif mask_hard is not None and mask_soft is None:
775
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
776
- mask = mask_hard.to(latent_cur.dtype)
777
- latent_cur = (init_latents_proper * mask) + (latent_cur * (1 - mask))
778
-
779
- else: # hard and soft are both none
780
- pass
781
-
782
- if return_intermediate is True:
783
- return latent_cur, intermediate_list
784
- else:
785
- return latent_cur
786
-
787
- @torch.no_grad()
788
- def sampling(
789
- self,
790
- set_string_list,
791
- cond_controller = None,
792
- uncond_controller = None,
793
- guidance_scale = 7,
794
- num_sampling_steps = 20,
795
- mask_hard = None,
796
- mask_soft = None,
797
- orig_image = None,
798
- strength = 1.,
799
- num_imgs = 1,
800
- normal_token_id_list = [],
801
- seed = 1
802
- ):
803
- weight_dtype = torch.float16
804
- self.scheduler.set_timesteps(num_sampling_steps)
805
- self.unet.to(device, dtype=weight_dtype)
806
- self.vae.to(device, dtype=weight_dtype)
807
- self.text_encoder.to(device, dtype=weight_dtype)
808
- self.text_encoder_2.to(device, dtype=weight_dtype)
809
- torch.manual_seed(seed)
810
- torch.cuda.manual_seed(seed)
811
-
812
- vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
813
- zT = torch.randn(num_imgs, 4, self.resolution//vae_scale_factor,self.resolution//vae_scale_factor).to(device,dtype=weight_dtype)
814
- zT = zT * self.scheduler.init_noise_sigma
815
-
816
- cond_emb_list, cond_add_text_embeds, add_time_ids = sdxl_prepare_input_decom(
817
- set_string_list,
818
- self.tokenizer,
819
- self.tokenizer_2,
820
- self.text_encoder,
821
- self.text_encoder_2,
822
- length = max_length,
823
- bsz = num_imgs,
824
- weight_dtype = weight_dtype,
825
- normal_token_id_list = normal_token_id_list
826
- )
827
-
828
- z0 = self.backward_zT_to_z0_euler_decom(zT, cond_emb_list, cond_add_text_embeds, add_time_ids,
829
- guidance_scale = guidance_scale, num_sampling_steps = num_sampling_steps,
830
- cond_controller = cond_controller, uncond_controller = uncond_controller,
831
- mask_hard = mask_hard, mask_soft = mask_soft, orig_image =orig_image, strength = strength
832
- )
833
- x0 = latent2image(z0, vae = self.vae)
834
- return x0
835
-
836
- @torch.no_grad()
837
- def inference_with_mask(
838
- self,
839
- save_path,
840
- guidance_scale = 3,
841
- num_sampling_steps = 50,
842
- strength = 1,
843
- mask_soft = None,
844
- mask_hard= None,
845
- orig_image=None,
846
- mask_list = None,
847
- num_imgs = 1,
848
- seed = 1,
849
- set_string_list = None
850
- ):
851
- if mask_list is not None:
852
- mask_list = [m.to(device) for m in mask_list]
853
- else:
854
- mask_list = self.mask_list
855
- if set_string_list is not None:
856
- self.set_string_list = set_string_list
857
-
858
- if mask_hard is not None and mask_soft is not None:
859
- check_mask_overlap_torch(mask_hard, mask_soft)
860
- null_controller = DummyController()
861
- decom_controller = GroupedCAController(mask_list = mask_list)
862
- x0 = self.sampling(
863
- self.set_string_list,
864
- guidance_scale = guidance_scale,
865
- num_sampling_steps = num_sampling_steps,
866
- strength = strength,
867
- cond_controller = decom_controller,
868
- uncond_controller = null_controller,
869
- mask_soft = mask_soft,
870
- mask_hard = mask_hard,
871
- orig_image = orig_image,
872
- num_imgs = num_imgs,
873
- seed = seed
874
- )
875
- save_images(x0, save_path)