Koke_Cacao committed on
Commit
1a29d24
1 Parent(s): 9db01cb

:bug: fix for SD 2.1

scripts/convert_mvdream_to_diffusers.py CHANGED
@@ -27,105 +27,6 @@ from transformers import CLIPTokenizer, CLIPTextModel
27
 
28
  logger = logging.get_logger(__name__)
29
 
30
- # def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
31
- # """
32
- # Creates a config for the diffusers based on the config of the LDM model.
33
- # """
34
- # if controlnet:
35
- # unet_params = original_config.model.params.control_stage_config.params
36
- # else:
37
- # if "unet_config" in original_config.model.params and original_config.model.params.unet_config is not None:
38
- # unet_params = original_config.model.params.unet_config.params
39
- # else:
40
- # unet_params = original_config.model.params.network_config.params
41
-
42
- # vae_params = original_config.model.params.first_stage_config.params.ddconfig
43
-
44
- # block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
45
-
46
- # down_block_types = []
47
- # resolution = 1
48
- # for i in range(len(block_out_channels)):
49
- # block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
50
- # down_block_types.append(block_type)
51
- # if i != len(block_out_channels) - 1:
52
- # resolution *= 2
53
-
54
- # up_block_types = []
55
- # for i in range(len(block_out_channels)):
56
- # block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
57
- # up_block_types.append(block_type)
58
- # resolution //= 2
59
-
60
- # if unet_params.transformer_depth is not None:
61
- # transformer_layers_per_block = (
62
- # unet_params.transformer_depth
63
- # if isinstance(unet_params.transformer_depth, int)
64
- # else list(unet_params.transformer_depth)
65
- # )
66
- # else:
67
- # transformer_layers_per_block = 1
68
-
69
- # vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
70
-
71
- # head_dim = unet_params.num_heads if "num_heads" in unet_params else None
72
- # use_linear_projection = (
73
- # unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
74
- # )
75
- # if use_linear_projection:
76
- # # stable diffusion 2-base-512 and 2-768
77
- # if head_dim is None:
78
- # head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
79
- # head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
80
-
81
- # class_embed_type = None
82
- # addition_embed_type = None
83
- # addition_time_embed_dim = None
84
- # projection_class_embeddings_input_dim = None
85
- # context_dim = None
86
-
87
- # if unet_params.context_dim is not None:
88
- # context_dim = (
89
- # unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
90
- # )
91
-
92
- # if "num_classes" in unet_params:
93
- # if unet_params.num_classes == "sequential":
94
- # if context_dim in [2048, 1280]:
95
- # # SDXL
96
- # addition_embed_type = "text_time"
97
- # addition_time_embed_dim = 256
98
- # else:
99
- # class_embed_type = "projection"
100
- # assert "adm_in_channels" in unet_params
101
- # projection_class_embeddings_input_dim = unet_params.adm_in_channels
102
- # else:
103
- # raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
104
-
105
- # config = {
106
- # "sample_size": image_size // vae_scale_factor,
107
- # "in_channels": unet_params.in_channels,
108
- # "down_block_types": tuple(down_block_types),
109
- # "block_out_channels": tuple(block_out_channels),
110
- # "layers_per_block": unet_params.num_res_blocks,
111
- # "cross_attention_dim": context_dim,
112
- # "attention_head_dim": head_dim,
113
- # "use_linear_projection": use_linear_projection,
114
- # "class_embed_type": class_embed_type,
115
- # "addition_embed_type": addition_embed_type,
116
- # "addition_time_embed_dim": addition_time_embed_dim,
117
- # "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
118
- # "transformer_layers_per_block": transformer_layers_per_block,
119
- # }
120
-
121
- # if controlnet:
122
- # config["conditioning_channels"] = unet_params.hint_channels
123
- # else:
124
- # config["out_channels"] = unet_params.out_channels
125
- # config["up_block_types"] = tuple(up_block_types)
126
-
127
- # return config
128
-
129
 
130
  def assign_to_checkpoint(paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None):
131
  """
@@ -190,291 +91,6 @@ def shave_segments(path, n_shave_prefix_segments=1):
190
  return ".".join(path.split(".")[:n_shave_prefix_segments])
191
 
192
 
193
- def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
194
- """
195
- Updates paths inside resnets to the new naming scheme (local renaming)
196
- """
197
- mapping = []
198
- for old_item in old_list:
199
- new_item = old_item.replace("in_layers.0", "norm1")
200
- new_item = new_item.replace("in_layers.2", "conv1")
201
-
202
- new_item = new_item.replace("out_layers.0", "norm2")
203
- new_item = new_item.replace("out_layers.3", "conv2")
204
-
205
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
206
- new_item = new_item.replace("skip_connection", "conv_shortcut")
207
-
208
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
209
-
210
- mapping.append({"old": old_item, "new": new_item})
211
-
212
- return mapping
213
-
214
-
215
- def renew_attention_paths(old_list, n_shave_prefix_segments=0):
216
- """
217
- Updates paths inside attentions to the new naming scheme (local renaming)
218
- """
219
- mapping = []
220
- for old_item in old_list:
221
- new_item = old_item
222
-
223
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
224
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
225
-
226
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
227
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
228
-
229
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
230
-
231
- mapping.append({"old": old_item, "new": new_item})
232
-
233
- return mapping
234
-
235
-
236
- # def convert_ldm_unet_checkpoint(
237
- # checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
238
- # ):
239
- # """
240
- # Takes a state dict and a config, and returns a converted checkpoint.
241
- # """
242
-
243
- # if skip_extract_state_dict:
244
- # unet_state_dict = checkpoint
245
- # else:
246
- # # extract state_dict for UNet
247
- # unet_state_dict = {}
248
- # keys = list(checkpoint.keys())
249
-
250
- # if controlnet:
251
- # unet_key = "control_model."
252
- # else:
253
- # unet_key = "model.diffusion_model."
254
-
255
- # # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
256
- # if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
257
- # logger.warning(f"Checkpoint {path} has both EMA and non-EMA weights.")
258
- # logger.warning(
259
- # "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
260
- # " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
261
- # )
262
- # for key in keys:
263
- # if key.startswith("model.diffusion_model"):
264
- # flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
265
- # unet_state_dict[key.replace(unet_key, "")] = checkpoint[flat_ema_key]
266
- # else:
267
- # if sum(k.startswith("model_ema") for k in keys) > 100:
268
- # logger.warning(
269
- # "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
270
- # " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
271
- # )
272
-
273
- # for key in keys:
274
- # if key.startswith(unet_key):
275
- # unet_state_dict[key.replace(unet_key, "")] = checkpoint[key]
276
-
277
- # new_checkpoint = {}
278
-
279
- # new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
280
- # new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
281
- # new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
282
- # new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
283
-
284
- # if config["class_embed_type"] is None:
285
- # # No parameters to port
286
- # ...
287
- # elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
288
- # new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
289
- # new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
290
- # new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
291
- # new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
292
- # else:
293
- # raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
294
-
295
- # if config["addition_embed_type"] == "text_time":
296
- # new_checkpoint["add_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
297
- # new_checkpoint["add_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
298
- # new_checkpoint["add_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
299
- # new_checkpoint["add_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
300
-
301
- # new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
302
- # new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
303
-
304
- # if not controlnet:
305
- # new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
306
- # new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
307
- # new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
308
- # new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
309
-
310
- # # Retrieves the keys for the input blocks only
311
- # num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
312
- # input_blocks = {
313
- # layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
314
- # for layer_id in range(num_input_blocks)
315
- # }
316
-
317
- # # Retrieves the keys for the middle blocks only
318
- # num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
319
- # middle_blocks = {
320
- # layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
321
- # for layer_id in range(num_middle_blocks)
322
- # }
323
-
324
- # # Retrieves the keys for the output blocks only
325
- # num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
326
- # output_blocks = {
327
- # layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
328
- # for layer_id in range(num_output_blocks)
329
- # }
330
-
331
- # for i in range(1, num_input_blocks):
332
- # block_id = (i - 1) // (config["layers_per_block"] + 1)
333
- # layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
334
-
335
- # resnets = [
336
- # key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
337
- # ]
338
- # attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
339
-
340
- # if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
341
- # new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
342
- # f"input_blocks.{i}.0.op.weight"
343
- # )
344
- # new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
345
- # f"input_blocks.{i}.0.op.bias"
346
- # )
347
-
348
- # paths = renew_resnet_paths(resnets)
349
- # meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
350
- # assign_to_checkpoint(
351
- # paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
352
- # )
353
-
354
- # if len(attentions):
355
- # paths = renew_attention_paths(attentions)
356
- # meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
357
- # assign_to_checkpoint(
358
- # paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
359
- # )
360
-
361
- # resnet_0 = middle_blocks[0]
362
- # attentions = middle_blocks[1]
363
- # resnet_1 = middle_blocks[2]
364
-
365
- # resnet_0_paths = renew_resnet_paths(resnet_0)
366
- # assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
367
-
368
- # resnet_1_paths = renew_resnet_paths(resnet_1)
369
- # assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
370
-
371
- # attentions_paths = renew_attention_paths(attentions)
372
- # meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
373
- # assign_to_checkpoint(
374
- # attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
375
- # )
376
-
377
- # for i in range(num_output_blocks):
378
- # block_id = i // (config["layers_per_block"] + 1)
379
- # layer_in_block_id = i % (config["layers_per_block"] + 1)
380
- # output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
381
- # output_block_list = {}
382
-
383
- # for layer in output_block_layers:
384
- # layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
385
- # if layer_id in output_block_list:
386
- # output_block_list[layer_id].append(layer_name)
387
- # else:
388
- # output_block_list[layer_id] = [layer_name]
389
-
390
- # if len(output_block_list) > 1:
391
- # resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
392
- # attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
393
-
394
- # resnet_0_paths = renew_resnet_paths(resnets)
395
- # paths = renew_resnet_paths(resnets)
396
-
397
- # meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
398
- # assign_to_checkpoint(
399
- # paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
400
- # )
401
-
402
- # output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
403
- # if ["conv.bias", "conv.weight"] in output_block_list.values():
404
- # index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
405
- # new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
406
- # f"output_blocks.{i}.{index}.conv.weight"
407
- # ]
408
- # new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
409
- # f"output_blocks.{i}.{index}.conv.bias"
410
- # ]
411
-
412
- # # Clear attentions as they have been attributed above.
413
- # if len(attentions) == 2:
414
- # attentions = []
415
-
416
- # if len(attentions):
417
- # paths = renew_attention_paths(attentions)
418
- # meta_path = {
419
- # "old": f"output_blocks.{i}.1",
420
- # "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
421
- # }
422
- # assign_to_checkpoint(
423
- # paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
424
- # )
425
- # else:
426
- # resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
427
- # for path in resnet_0_paths:
428
- # old_path = ".".join(["output_blocks", str(i), path["old"]])
429
- # new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
430
-
431
- # new_checkpoint[new_path] = unet_state_dict[old_path]
432
-
433
- # if controlnet:
434
- # # conditioning embedding
435
-
436
- # orig_index = 0
437
-
438
- # new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
439
- # f"input_hint_block.{orig_index}.weight"
440
- # )
441
- # new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
442
- # f"input_hint_block.{orig_index}.bias"
443
- # )
444
-
445
- # orig_index += 2
446
-
447
- # diffusers_index = 0
448
-
449
- # while diffusers_index < 6:
450
- # new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
451
- # f"input_hint_block.{orig_index}.weight"
452
- # )
453
- # new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
454
- # f"input_hint_block.{orig_index}.bias"
455
- # )
456
- # diffusers_index += 1
457
- # orig_index += 2
458
-
459
- # new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
460
- # f"input_hint_block.{orig_index}.weight"
461
- # )
462
- # new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
463
- # f"input_hint_block.{orig_index}.bias"
464
- # )
465
-
466
- # # down blocks
467
- # for i in range(num_input_blocks):
468
- # new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
469
- # new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
470
-
471
- # # mid block
472
- # new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
473
- # new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
474
-
475
- # return new_checkpoint
476
-
477
-
478
  def create_vae_diffusers_config(original_config, image_size: int):
479
  """
480
  Creates a config for the diffusers based on the config of the LDM model.
@@ -706,8 +322,14 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
706
  with init_empty_weights():
707
  vae = AutoencoderKL(**vae_config)
708
 
709
- tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
710
- text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=torch.device("cuda:0")) # type: ignore
325
+ if original_config.model.params.unet_config.params.context_dim == 768:
326
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
327
+ text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=torch.device("cuda:0")) # type: ignore
328
+ elif original_config.model.params.unet_config.params.context_dim == 1024:
329
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
330
+ text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=torch.device("cuda:0")) # type: ignore
331
+ else:
332
+ raise ValueError(f"Unknown context_dim: {original_config.model.params.unet_config.params.context_dim}")
711
 
712
  for param_name, param in converted_vae_checkpoint.items():
713
  set_module_tensor_to_device(vae, param_name, "cuda:0", value=param)
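
The added branch keys the text encoder choice off the UNet's cross-attention width (context_dim): 768 corresponds to the CLIP ViT-L/14 encoder used by SD 1.x-style MVDream weights, while 1024 corresponds to the 1024-dim OpenCLIP ViT-H encoder distributed in the diffusers layout of stabilityai/stable-diffusion-2-1. A minimal sketch of that dispatch in isolation (the helper name and the device argument are illustrative, not part of the script):

from transformers import CLIPTextModel, CLIPTokenizer

def load_text_encoder(context_dim: int, device: str = "cpu"):
    """Pick the tokenizer/text encoder that matches the UNet cross-attention width."""
    if context_dim == 768:
        # SD 1.x-style checkpoints embed prompts with the 768-dim CLIP ViT-L/14 encoder.
        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    elif context_dim == 1024:
        # SD 2.1-based checkpoints use the 1024-dim OpenCLIP ViT-H encoder shipped
        # with stabilityai/stable-diffusion-2-1.
        tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder")
    else:
        raise ValueError(f"Unknown context_dim: {context_dim}")
    return tokenizer, text_encoder.to(device)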