twodgirl committed
Commit ca2a32e
1 Parent(s): d2f5467

Upload checkpoint-to-diffusers converter.

Files changed (1)
  1. convert_sd3_to_diffusers.py (+248, -0)
convert_sd3_to_diffusers.py ADDED
@@ -0,0 +1,248 @@
+ import argparse
+ from contextlib import nullcontext
+ 
+ import safetensors.torch
+ import torch
+ from accelerate import init_empty_weights
+ 
+ from diffusers import AutoencoderKL, SD3Transformer2DModel
+ from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
+ from diffusers.models.modeling_utils import load_model_dict_into_meta
+ from diffusers.utils.import_utils import is_accelerate_available
+ 
+ 
+ CTX = init_empty_weights if is_accelerate_available() else nullcontext
+ 
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--checkpoint_path", type=str)
+ parser.add_argument("--output_path", type=str)
+ parser.add_argument("--dtype", type=str, default="fp16")
+ 
+ args = parser.parse_args()
+ dtype = torch.float16 if args.dtype == "fp16" else torch.float32
+ 
+ 
+ def load_original_checkpoint(ckpt_path):
+     original_state_dict = safetensors.torch.load_file(ckpt_path)
+     keys = list(original_state_dict.keys())
+     for k in keys:
+         if "model.diffusion_model." in k:
+             original_state_dict[k.replace("model.diffusion_model.", "")] = original_state_dict.pop(k)
+ 
+     return original_state_dict
+ 
+ 
+ # In the original SD3 implementation of AdaLayerNormContinuous, the linear projection output is split into shift, scale,
+ # while diffusers splits it into scale, shift. Swap the two halves of the linear projection weights so the diffusers implementation can be used.
+ def swap_scale_shift(weight, dim):
+     shift, scale = weight.chunk(2, dim=0)
+     new_weight = torch.cat([scale, shift], dim=0)
+     return new_weight
+ 
+ 
+ def convert_sd3_transformer_checkpoint_to_diffusers(original_state_dict, num_layers, caption_projection_dim):
+     converted_state_dict = {}
+ 
+     # Positional and patch embeddings.
+     converted_state_dict["pos_embed.pos_embed"] = original_state_dict.pop("pos_embed")
+     converted_state_dict["pos_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
+     converted_state_dict["pos_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
+ 
+     # Timestep embeddings.
+     converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+         "t_embedder.mlp.0.weight"
+     )
+     converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+         "t_embedder.mlp.0.bias"
+     )
+     converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+         "t_embedder.mlp.2.weight"
+     )
+     converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+         "t_embedder.mlp.2.bias"
+     )
+ 
+     # Context projections.
+     converted_state_dict["context_embedder.weight"] = original_state_dict.pop("context_embedder.weight")
+     converted_state_dict["context_embedder.bias"] = original_state_dict.pop("context_embedder.bias")
+ 
+     # Pooled context projection.
+     converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
+         "y_embedder.mlp.0.weight"
+     )
+     converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
+         "y_embedder.mlp.0.bias"
+     )
+     converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
+         "y_embedder.mlp.2.weight"
+     )
+     converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
+         "y_embedder.mlp.2.bias"
+     )
+ 
+     # Transformer blocks 🎸.
+     for i in range(num_layers):
+         # Q, K, V
+         sample_q, sample_k, sample_v = torch.chunk(
+             original_state_dict.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
+         )
+         context_q, context_k, context_v = torch.chunk(
+             original_state_dict.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
+         )
+         sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+             original_state_dict.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
+         )
+         context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+             original_state_dict.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
+         )
+ 
+         converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
+         converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
+         converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
+         converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
+         converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
+         converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
+ 
+         converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
+         converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
+         converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
+         converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
+         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
+         converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
+ 
+         # Output projections.
+         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = original_state_dict.pop(
+             f"joint_blocks.{i}.x_block.attn.proj.weight"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = original_state_dict.pop(
+             f"joint_blocks.{i}.x_block.attn.proj.bias"
+         )
+         if not (i == num_layers - 1):
+             converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = original_state_dict.pop(
+                 f"joint_blocks.{i}.context_block.attn.proj.weight"
+             )
+             converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = original_state_dict.pop(
+                 f"joint_blocks.{i}.context_block.attn.proj.bias"
+             )
+ 
+         # Norms.
+         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = original_state_dict.pop(
+             f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = original_state_dict.pop(
+             f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
+         )
+         if not (i == num_layers - 1):
+             converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = original_state_dict.pop(
+                 f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
+             )
+             converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = original_state_dict.pop(
+                 f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
+             )
+         else:
+             converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
+                 original_state_dict.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
+                 dim=caption_projection_dim,
+             )
+             converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
+                 original_state_dict.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
+                 dim=caption_projection_dim,
+             )
+ 
+         # Feed-forward layers.
+         converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = original_state_dict.pop(
+             f"joint_blocks.{i}.x_block.mlp.fc1.weight"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = original_state_dict.pop(
+             f"joint_blocks.{i}.x_block.mlp.fc1.bias"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = original_state_dict.pop(
+             f"joint_blocks.{i}.x_block.mlp.fc2.weight"
+         )
+         converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = original_state_dict.pop(
+             f"joint_blocks.{i}.x_block.mlp.fc2.bias"
+         )
+         if not (i == num_layers - 1):
+             converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = original_state_dict.pop(
+                 f"joint_blocks.{i}.context_block.mlp.fc1.weight"
+             )
+             converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = original_state_dict.pop(
+                 f"joint_blocks.{i}.context_block.mlp.fc1.bias"
+             )
+             converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = original_state_dict.pop(
+                 f"joint_blocks.{i}.context_block.mlp.fc2.weight"
+             )
+             converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = original_state_dict.pop(
+                 f"joint_blocks.{i}.context_block.mlp.fc2.bias"
+             )
+ 
+     # Final blocks.
+     converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
+     converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
+     converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+         original_state_dict.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
+     )
+     converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
+         original_state_dict.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
+     )
+ 
+     return converted_state_dict
+ 
+ 
+ def is_vae_in_checkpoint(original_state_dict):
+     return ("first_stage_model.decoder.conv_in.weight" in original_state_dict) and (
+         "first_stage_model.encoder.conv_in.weight" in original_state_dict
+     )
+ 
+ 
+ def main(args):
+     original_ckpt = load_original_checkpoint(args.checkpoint_path)
+     num_layers = max(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k) + 1
+     caption_projection_dim = 1536
+ 
+     converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
+         original_ckpt, num_layers, caption_projection_dim
+     )
+ 
+     with CTX():
+         transformer = SD3Transformer2DModel(
+             sample_size=64,
+             patch_size=2,
+             in_channels=16,
+             joint_attention_dim=4096,
+             num_layers=num_layers,
+             caption_projection_dim=caption_projection_dim,
+             num_attention_heads=24,
+             pos_embed_max_size=192,
+         )
+     if is_accelerate_available():
+         load_model_dict_into_meta(transformer, converted_transformer_state_dict)
+     else:
+         transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+ 
+     print("Saving SD3 Transformer in Diffusers format.")
+     transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+ 
+     if is_vae_in_checkpoint(original_ckpt):
+         with CTX():
+             vae = AutoencoderKL.from_config(
+                 "stabilityai/stable-diffusion-xl-base-1.0",
+                 subfolder="vae",
+                 latent_channels=16,
+                 use_post_quant_conv=False,
+                 use_quant_conv=False,
+                 scaling_factor=1.5305,
+                 shift_factor=0.0609,
+             )
+         converted_vae_state_dict = convert_ldm_vae_checkpoint(original_ckpt, vae.config)
+         if is_accelerate_available():
+             load_model_dict_into_meta(vae, converted_vae_state_dict)
+         else:
+             vae.load_state_dict(converted_vae_state_dict, strict=True)
+ 
+         print("Saving SD3 Autoencoder in Diffusers format.")
+         vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
+ 
+ 
+ if __name__ == "__main__":
+     main(args)
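
For reference, a minimal usage sketch. The checkpoint filename, output directory, prompt, and base repository id below are placeholders and assumptions, not part of this commit; only the transformer/ and vae/ subfolders are produced by the script above.

# Convert the single-file checkpoint, e.g.:
#   python convert_sd3_to_diffusers.py --checkpoint_path sd3_medium.safetensors --output_path sd3-converted --dtype fp16

import torch
from diffusers import AutoencoderKL, SD3Transformer2DModel, StableDiffusion3Pipeline

# Load the converted subfolders and plug them into an existing SD3 pipeline,
# which supplies the text encoders, tokenizers, and scheduler.
transformer = SD3Transformer2DModel.from_pretrained("sd3-converted/transformer", torch_dtype=torch.float16)
vae = AutoencoderKL.from_pretrained("sd3-converted/vae", torch_dtype=torch.float16)
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed base repo for the remaining components
    transformer=transformer,
    vae=vae,
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a photo of a corgi wearing sunglasses", num_inference_steps=28).images[0]
image.save("corgi.png")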