ManglerFTW's picture
Upload 12 files
3a18eba
# v1: split from train_db_fixed.py.
# v2: support safetensors
import math
import os
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
from diffusers import (
AutoencoderKL,
DDIMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from safetensors.torch import load_file, save_file
# Model parameters for the Diffusers version of Stable Diffusion
NUM_TRAIN_TIMESTEPS = 1000
BETA_START = 0.00085
BETA_END = 0.0120
UNET_PARAMS_MODEL_CHANNELS = 320
UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
# Originally marked "unused"; NOTE(review): create_unet_diffusers_config
# passes this value as `sample_size`, so it is in fact read.
UNET_PARAMS_IMAGE_SIZE = 32
UNET_PARAMS_IN_CHANNELS = 4
UNET_PARAMS_OUT_CHANNELS = 4
UNET_PARAMS_NUM_RES_BLOCKS = 2
UNET_PARAMS_CONTEXT_DIM = 768
UNET_PARAMS_NUM_HEADS = 8
VAE_PARAMS_Z_CHANNELS = 4
VAE_PARAMS_RESOLUTION = 256
VAE_PARAMS_IN_CHANNELS = 3
VAE_PARAMS_OUT_CH = 3
VAE_PARAMS_CH = 128
VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
VAE_PARAMS_NUM_RES_BLOCKS = 2
# V2 (values that replace the v1 ones when the `v2` flag is set)
V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
V2_UNET_PARAMS_CONTEXT_DIM = 1024
# Reference models used to load the Diffusers configuration
DIFFUSERS_REF_MODEL_ID_V1 = 'runwayml/stable-diffusion-v1-5'
DIFFUSERS_REF_MODEL_ID_V2 = 'stabilityai/stable-diffusion-2-1'
# region Conversion code: StableDiffusion -> Diffusers
# Copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
def shave_segments(path, n_shave_prefix_segments=1):
    """Drop dot-separated segments from one end of ``path``.

    A non-negative ``n_shave_prefix_segments`` removes that many leading
    segments; a negative value removes that many trailing segments.
    The remaining segments are re-joined with ``'.'``.
    """
    segments = path.split('.')
    if n_shave_prefix_segments < 0:
        kept = segments[:n_shave_prefix_segments]
    else:
        kept = segments[n_shave_prefix_segments:]
    return '.'.join(kept)
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
    """Rename UNet resnet keys to the new naming scheme (local renaming).

    Each key in ``old_list`` has its layer-name fragments substituted in a
    fixed order, then ``shave_segments`` is applied with
    ``n_shave_prefix_segments``. Returns a list of
    ``{'old': original_key, 'new': renamed_key}`` dicts in input order.
    """
    # Applied in order, exactly one str.replace per pair.
    substitutions = (
        ('in_layers.0', 'norm1'),
        ('in_layers.2', 'conv1'),
        ('out_layers.0', 'norm2'),
        ('out_layers.3', 'conv2'),
        ('emb_layers.1', 'time_emb_proj'),
        ('skip_connection', 'conv_shortcut'),
    )
    mapping = []
    for original in old_list:
        renamed = original
        for old_part, new_part in substitutions:
            renamed = renamed.replace(old_part, new_part)
        renamed = shave_segments(
            renamed, n_shave_prefix_segments=n_shave_prefix_segments
        )
        mapping.append({'old': original, 'new': renamed})
    return mapping
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
    """Rename VAE resnet keys to the new naming scheme (local renaming).

    Replaces ``'nin_shortcut'`` with ``'conv_shortcut'`` in every key, then
    applies ``shave_segments`` with ``n_shave_prefix_segments``. Returns a
    list of ``{'old': original_key, 'new': renamed_key}`` dicts.
    """
    mapping = []
    for original in old_list:
        renamed = shave_segments(
            original.replace('nin_shortcut', 'conv_shortcut'),
            n_shave_prefix_segments=n_shave_prefix_segments,
        )
        mapping.append({'old': original, 'new': renamed})
    return mapping
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
    """Build the rename mapping for UNet attention keys (local renaming).

    No local renaming is applied here: every key maps to itself.
    ``n_shave_prefix_segments`` is accepted but not used by this function.
    Returns a list of ``{'old': key, 'new': key}`` dicts in input order.
    """
    return [{'old': item, 'new': item} for item in old_list]
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
    """Rename VAE attention keys to the new naming scheme (local renaming).

    For each key, the ``.weight`` / ``.bias`` entries of ``norm``, ``q``,
    ``k``, ``v`` and ``proj_out`` are substituted (in that order) with
    ``group_norm``, ``query``, ``key``, ``value`` and ``proj_attn``; then
    ``shave_segments`` is applied with ``n_shave_prefix_segments``.
    Returns a list of ``{'old': original_key, 'new': renamed_key}`` dicts.
    """
    # Order matters: replacements are applied sequentially to the same string.
    module_renames = (
        ('norm', 'group_norm'),
        ('q', 'query'),
        ('k', 'key'),
        ('v', 'value'),
        ('proj_out', 'proj_attn'),
    )
    mapping = []
    for original in old_list:
        renamed = original
        for old_module, new_module in module_renames:
            for param in ('weight', 'bias'):
                renamed = renamed.replace(
                    f'{old_module}.{param}', f'{new_module}.{param}'
                )
        renamed = shave_segments(
            renamed, n_shave_prefix_segments=n_shave_prefix_segments
        )
        mapping.append({'old': original, 'new': renamed})
    return mapping
def assign_to_checkpoint(
    paths,
    checkpoint,
    old_checkpoint,
    attention_paths_to_split=None,
    additional_replacements=None,
    config=None,
):
    """Copy locally renamed weights into ``checkpoint`` under global names.

    Args:
        paths: list of ``{'old': ..., 'new': ...}`` dicts.
        checkpoint: destination dict, mutated in place.
        old_checkpoint: source dict indexed by the ``'old'`` keys.
        attention_paths_to_split: optional dict mapping a fused attention key
            in ``old_checkpoint`` to a dict with ``'query'``, ``'key'`` and
            ``'value'`` destination keys; the fused tensor is split in three.
        additional_replacements: optional list of ``{'old': ..., 'new': ...}``
            substring replacements applied after the built-in
            ``middle_block`` renames.
        config: dict; ``config['num_head_channels']`` is read only when
            ``attention_paths_to_split`` is given.
    """
    assert isinstance(
        paths, list
    ), "Paths should be a list of dicts containing 'old' and 'new' keys."

    # Split fused attention tensors into query / key / value.
    if attention_paths_to_split is not None:
        for fused_key, targets in attention_paths_to_split.items():
            fused = old_checkpoint[fused_key]
            channels = fused.shape[0] // 3
            if len(fused.shape) == 3:
                target_shape = (-1, channels)
            else:
                target_shape = (-1)
            num_heads = fused.shape[0] // config['num_head_channels'] // 3
            fused = fused.reshape(
                (num_heads, 3 * channels // num_heads) + fused.shape[1:]
            )
            query, key, value = fused.split(channels // num_heads, dim=1)
            checkpoint[targets['query']] = query.reshape(target_shape)
            checkpoint[targets['key']] = key.reshape(target_shape)
            checkpoint[targets['value']] = value.reshape(target_shape)

    # Global renames, applied in this order before any caller-supplied ones.
    global_renames = (
        ('middle_block.0', 'mid_block.resnets.0'),
        ('middle_block.1', 'mid_block.attentions.0'),
        ('middle_block.2', 'mid_block.resnets.1'),
    )
    for entry in paths:
        target = entry['new']
        already_split = (
            attention_paths_to_split is not None
            and target in attention_paths_to_split
        )
        if already_split:
            continue
        for old_part, new_part in global_renames:
            target = target.replace(old_part, new_part)
        if additional_replacements is not None:
            for replacement in additional_replacements:
                target = target.replace(
                    replacement['old'], replacement['new']
                )
        source = old_checkpoint[entry['old']]
        if 'proj_attn.weight' in target:
            # proj_attn.weight is converted from conv 1D to linear
            # by dropping the trailing dimension.
            checkpoint[target] = source[:, :, 0]
        else:
            checkpoint[target] = source
def conv_attn_to_linear(checkpoint):
    """Squeeze conv-shaped attention weights in ``checkpoint`` in place.

    Keys ending in ``query.weight`` / ``key.weight`` / ``value.weight`` with
    more than two dimensions are indexed as ``[:, :, 0, 0]``; other keys
    containing ``proj_attn.weight`` with more than two dimensions are indexed
    as ``[:, :, 0]``. Everything else is left untouched.
    """
    attn_suffixes = ['query.weight', 'key.weight', 'value.weight']
    for name in list(checkpoint.keys()):
        tensor = checkpoint[name]
        suffix = '.'.join(name.split('.')[-2:])
        if suffix in attn_suffixes:
            if tensor.ndim > 2:
                checkpoint[name] = tensor[:, :, 0, 0]
        elif 'proj_attn.weight' in name and tensor.ndim > 2:
            checkpoint[name] = tensor[:, :, 0]
def linear_transformer_to_conv(checkpoint):
    """Expand 2-D ``proj_in`` / ``proj_out`` weights to 4-D in place.

    For every key whose last two dot-segments are ``proj_in.weight`` or
    ``proj_out.weight`` and whose tensor is 2-D, two trailing unit
    dimensions are appended. Other entries are not modified.
    """
    transformer_suffixes = ['proj_in.weight', 'proj_out.weight']
    for name in list(checkpoint.keys()):
        if '.'.join(name.split('.')[-2:]) not in transformer_suffixes:
            continue
        weight = checkpoint[name]
        if weight.ndim == 2:
            checkpoint[name] = weight.unsqueeze(2).unsqueeze(2)
def convert_ldm_unet_checkpoint(v2, checkpoint, config):
    """
    Takes a state dict and a config, and returns a converted checkpoint.

    Args:
        v2: when truthy, ``linear_transformer_to_conv`` is applied to the
            result before returning.
        checkpoint: full state dict. Every key starting with
            ``'model.diffusion_model.'`` is popped from it, so the caller's
            dict is mutated.
        config: dict; ``config['layers_per_block']`` is read here and the
            dict is forwarded to ``assign_to_checkpoint``.

    Returns:
        A new dict holding the UNet weights under the renamed keys.

    Raises:
        KeyError: if one of the fixed keys read below (``time_embed.*``,
            ``input_blocks.0.0.*``, ``out.*``) is missing.
    """
    # extract state_dict for UNet
    unet_state_dict = {}
    unet_key = 'model.diffusion_model.'
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(unet_key):
            unet_state_dict[key.replace(unet_key, '')] = checkpoint.pop(key)
    new_checkpoint = {}
    # Fixed one-to-one renames for the time embedding and the in/out convs.
    new_checkpoint['time_embedding.linear_1.weight'] = unet_state_dict[
        'time_embed.0.weight'
    ]
    new_checkpoint['time_embedding.linear_1.bias'] = unet_state_dict[
        'time_embed.0.bias'
    ]
    new_checkpoint['time_embedding.linear_2.weight'] = unet_state_dict[
        'time_embed.2.weight'
    ]
    new_checkpoint['time_embedding.linear_2.bias'] = unet_state_dict[
        'time_embed.2.bias'
    ]
    new_checkpoint['conv_in.weight'] = unet_state_dict[
        'input_blocks.0.0.weight'
    ]
    new_checkpoint['conv_in.bias'] = unet_state_dict['input_blocks.0.0.bias']
    new_checkpoint['conv_norm_out.weight'] = unet_state_dict['out.0.weight']
    new_checkpoint['conv_norm_out.bias'] = unet_state_dict['out.0.bias']
    new_checkpoint['conv_out.weight'] = unet_state_dict['out.2.weight']
    new_checkpoint['conv_out.bias'] = unet_state_dict['out.2.bias']
    # Retrieves the keys for the input blocks only
    # (the block count is the number of distinct first-two-segment prefixes)
    num_input_blocks = len(
        {
            '.'.join(layer.split('.')[:2])
            for layer in unet_state_dict
            if 'input_blocks' in layer
        }
    )
    input_blocks = {
        layer_id: [
            key
            for key in unet_state_dict
            if f'input_blocks.{layer_id}.' in key
        ]
        for layer_id in range(num_input_blocks)
    }
    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len(
        {
            '.'.join(layer.split('.')[:2])
            for layer in unet_state_dict
            if 'middle_block' in layer
        }
    )
    middle_blocks = {
        layer_id: [
            key
            for key in unet_state_dict
            if f'middle_block.{layer_id}.' in key
        ]
        for layer_id in range(num_middle_blocks)
    }
    # Retrieves the keys for the output blocks only
    num_output_blocks = len(
        {
            '.'.join(layer.split('.')[:2])
            for layer in unet_state_dict
            if 'output_blocks' in layer
        }
    )
    output_blocks = {
        layer_id: [
            key
            for key in unet_state_dict
            if f'output_blocks.{layer_id}.' in key
        ]
        for layer_id in range(num_output_blocks)
    }
    # Input blocks: index 0 is conv_in (handled above), so start at 1.
    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config['layers_per_block'] + 1)
        layer_in_block_id = (i - 1) % (config['layers_per_block'] + 1)
        resnets = [
            key
            for key in input_blocks[i]
            if f'input_blocks.{i}.0' in key
            and f'input_blocks.{i}.0.op' not in key
        ]
        attentions = [
            key for key in input_blocks[i] if f'input_blocks.{i}.1' in key
        ]
        # An `.op` entry is mapped to the block's downsampler conv and
        # removed from unet_state_dict.
        if f'input_blocks.{i}.0.op.weight' in unet_state_dict:
            new_checkpoint[
                f'down_blocks.{block_id}.downsamplers.0.conv.weight'
            ] = unet_state_dict.pop(f'input_blocks.{i}.0.op.weight')
            new_checkpoint[
                f'down_blocks.{block_id}.downsamplers.0.conv.bias'
            ] = unet_state_dict.pop(f'input_blocks.{i}.0.op.bias')
        paths = renew_resnet_paths(resnets)
        meta_path = {
            'old': f'input_blocks.{i}.0',
            'new': f'down_blocks.{block_id}.resnets.{layer_in_block_id}',
        }
        assign_to_checkpoint(
            paths,
            new_checkpoint,
            unet_state_dict,
            additional_replacements=[meta_path],
            config=config,
        )
        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {
                'old': f'input_blocks.{i}.1',
                'new': f'down_blocks.{block_id}.attentions.{layer_in_block_id}',
            }
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )
    # Middle block: resnet, attention, resnet (indices 0, 1, 2).
    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]
    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(
        resnet_0_paths, new_checkpoint, unet_state_dict, config=config
    )
    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(
        resnet_1_paths, new_checkpoint, unet_state_dict, config=config
    )
    attentions_paths = renew_attention_paths(attentions)
    meta_path = {'old': 'middle_block.1', 'new': 'mid_block.attentions.0'}
    assign_to_checkpoint(
        attentions_paths,
        new_checkpoint,
        unet_state_dict,
        additional_replacements=[meta_path],
        config=config,
    )
    # Output blocks.
    for i in range(num_output_blocks):
        block_id = i // (config['layers_per_block'] + 1)
        layer_in_block_id = i % (config['layers_per_block'] + 1)
        output_block_layers = [
            shave_segments(name, 2) for name in output_blocks[i]
        ]
        # Group the remaining key tails by their leading sub-layer index.
        output_block_list = {}
        for layer in output_block_layers:
            layer_id, layer_name = layer.split('.')[0], shave_segments(
                layer, 1
            )
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]
        if len(output_block_list) > 1:
            resnets = [
                key
                for key in output_blocks[i]
                if f'output_blocks.{i}.0' in key
            ]
            attentions = [
                key
                for key in output_blocks[i]
                if f'output_blocks.{i}.1' in key
            ]
            # NOTE(review): resnet_0_paths is assigned here but not read in
            # this branch; only `paths` is used below.
            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)
            meta_path = {
                'old': f'output_blocks.{i}.0',
                'new': f'up_blocks.{block_id}.resnets.{layer_in_block_id}',
            }
            assign_to_checkpoint(
                paths,
                new_checkpoint,
                unet_state_dict,
                additional_replacements=[meta_path],
                config=config,
            )
            # Original:
            # if ["conv.weight", "conv.bias"] in output_block_list.values():
            # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
            # Make this independent of the order of bias and weight
            # (there is probably a nicer way to do it).
            for l in output_block_list.values():
                l.sort()
            if ['conv.bias', 'conv.weight'] in output_block_list.values():
                index = list(output_block_list.values()).index(
                    ['conv.bias', 'conv.weight']
                )
                new_checkpoint[
                    f'up_blocks.{block_id}.upsamplers.0.conv.bias'
                ] = unet_state_dict[f'output_blocks.{i}.{index}.conv.bias']
                new_checkpoint[
                    f'up_blocks.{block_id}.upsamplers.0.conv.weight'
                ] = unet_state_dict[f'output_blocks.{i}.{index}.conv.weight']
                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []
            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    'old': f'output_blocks.{i}.1',
                    'new': f'up_blocks.{block_id}.attentions.{layer_in_block_id}',
                }
                assign_to_checkpoint(
                    paths,
                    new_checkpoint,
                    unet_state_dict,
                    additional_replacements=[meta_path],
                    config=config,
                )
        else:
            # Only one sub-layer: copy it directly as a resnet.
            resnet_0_paths = renew_resnet_paths(
                output_block_layers, n_shave_prefix_segments=1
            )
            for path in resnet_0_paths:
                old_path = '.'.join(['output_blocks', str(i), path['old']])
                new_path = '.'.join(
                    [
                        'up_blocks',
                        str(block_id),
                        'resnets',
                        str(layer_in_block_id),
                        path['new'],
                    ]
                )
                new_checkpoint[new_path] = unet_state_dict[old_path]
    # In SD v2 the 1x1 conv2d layers were changed to linear, so convert
    # linear -> conv here.
    if v2:
        linear_transformer_to_conv(new_checkpoint)
    return new_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
    """Extract the VAE weights from ``checkpoint`` and rename their keys.

    Args:
        checkpoint: full state dict; keys starting with
            ``'first_stage_model.'`` are read with ``get`` (they are not
            removed from ``checkpoint``).
        config: forwarded unchanged to ``assign_to_checkpoint``.

    Returns:
        A new dict holding the VAE weights under the renamed keys.

    Raises:
        KeyError: if one of the fixed encoder/decoder/quant keys read below
            is missing.
    """
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = 'first_stage_model.'
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, '')] = checkpoint.get(key)
    # if len(vae_state_dict) == 0:
    # # The passed checkpoint is not one loaded from a .ckpt but a VAE state_dict
    # vae_state_dict = checkpoint
    new_checkpoint = {}
    # Fixed one-to-one copies / renames (norm_out -> conv_norm_out).
    new_checkpoint['encoder.conv_in.weight'] = vae_state_dict[
        'encoder.conv_in.weight'
    ]
    new_checkpoint['encoder.conv_in.bias'] = vae_state_dict[
        'encoder.conv_in.bias'
    ]
    new_checkpoint['encoder.conv_out.weight'] = vae_state_dict[
        'encoder.conv_out.weight'
    ]
    new_checkpoint['encoder.conv_out.bias'] = vae_state_dict[
        'encoder.conv_out.bias'
    ]
    new_checkpoint['encoder.conv_norm_out.weight'] = vae_state_dict[
        'encoder.norm_out.weight'
    ]
    new_checkpoint['encoder.conv_norm_out.bias'] = vae_state_dict[
        'encoder.norm_out.bias'
    ]
    new_checkpoint['decoder.conv_in.weight'] = vae_state_dict[
        'decoder.conv_in.weight'
    ]
    new_checkpoint['decoder.conv_in.bias'] = vae_state_dict[
        'decoder.conv_in.bias'
    ]
    new_checkpoint['decoder.conv_out.weight'] = vae_state_dict[
        'decoder.conv_out.weight'
    ]
    new_checkpoint['decoder.conv_out.bias'] = vae_state_dict[
        'decoder.conv_out.bias'
    ]
    new_checkpoint['decoder.conv_norm_out.weight'] = vae_state_dict[
        'decoder.norm_out.weight'
    ]
    new_checkpoint['decoder.conv_norm_out.bias'] = vae_state_dict[
        'decoder.norm_out.bias'
    ]
    new_checkpoint['quant_conv.weight'] = vae_state_dict['quant_conv.weight']
    new_checkpoint['quant_conv.bias'] = vae_state_dict['quant_conv.bias']
    new_checkpoint['post_quant_conv.weight'] = vae_state_dict[
        'post_quant_conv.weight'
    ]
    new_checkpoint['post_quant_conv.bias'] = vae_state_dict[
        'post_quant_conv.bias'
    ]
    # Retrieves the keys for the encoder down blocks only
    # (the block count is the number of distinct first-three-segment prefixes)
    num_down_blocks = len(
        {
            '.'.join(layer.split('.')[:3])
            for layer in vae_state_dict
            if 'encoder.down' in layer
        }
    )
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f'down.{layer_id}' in key]
        for layer_id in range(num_down_blocks)
    }
    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len(
        {
            '.'.join(layer.split('.')[:3])
            for layer in vae_state_dict
            if 'decoder.up' in layer
        }
    )
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f'up.{layer_id}' in key]
        for layer_id in range(num_up_blocks)
    }
    # Encoder down blocks.
    for i in range(num_down_blocks):
        resnets = [
            key
            for key in down_blocks[i]
            if f'down.{i}' in key and f'down.{i}.downsample' not in key
        ]
        # Downsample convs are moved (popped) into the downsamplers slot.
        if f'encoder.down.{i}.downsample.conv.weight' in vae_state_dict:
            new_checkpoint[
                f'encoder.down_blocks.{i}.downsamplers.0.conv.weight'
            ] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.weight')
            new_checkpoint[
                f'encoder.down_blocks.{i}.downsamplers.0.conv.bias'
            ] = vae_state_dict.pop(f'encoder.down.{i}.downsample.conv.bias')
        paths = renew_vae_resnet_paths(resnets)
        meta_path = {
            'old': f'down.{i}.block',
            'new': f'down_blocks.{i}.resnets',
        }
        assign_to_checkpoint(
            paths,
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[meta_path],
            config=config,
        )
    # Encoder mid block: block_1 / block_2 become resnets.0 / resnets.1.
    mid_resnets = [key for key in vae_state_dict if 'encoder.mid.block' in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [
            key for key in mid_resnets if f'encoder.mid.block_{i}' in key
        ]
        paths = renew_vae_resnet_paths(resnets)
        meta_path = {
            'old': f'mid.block_{i}',
            'new': f'mid_block.resnets.{i - 1}',
        }
        assign_to_checkpoint(
            paths,
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[meta_path],
            config=config,
        )
    # Encoder mid attention.
    mid_attentions = [
        key for key in vae_state_dict if 'encoder.mid.attn' in key
    ]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        vae_state_dict,
        additional_replacements=[meta_path],
        config=config,
    )
    conv_attn_to_linear(new_checkpoint)
    # Decoder up blocks: destination index i reads source block
    # num_up_blocks - 1 - i, i.e. the numbering is reversed.
    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key
            for key in up_blocks[block_id]
            if f'up.{block_id}' in key and f'up.{block_id}.upsample' not in key
        ]
        # Unlike the downsample case, the upsample convs are copied, not popped.
        if f'decoder.up.{block_id}.upsample.conv.weight' in vae_state_dict:
            new_checkpoint[
                f'decoder.up_blocks.{i}.upsamplers.0.conv.weight'
            ] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.weight']
            new_checkpoint[
                f'decoder.up_blocks.{i}.upsamplers.0.conv.bias'
            ] = vae_state_dict[f'decoder.up.{block_id}.upsample.conv.bias']
        paths = renew_vae_resnet_paths(resnets)
        meta_path = {
            'old': f'up.{block_id}.block',
            'new': f'up_blocks.{i}.resnets',
        }
        assign_to_checkpoint(
            paths,
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[meta_path],
            config=config,
        )
    # Decoder mid block resnets.
    mid_resnets = [key for key in vae_state_dict if 'decoder.mid.block' in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [
            key for key in mid_resnets if f'decoder.mid.block_{i}' in key
        ]
        paths = renew_vae_resnet_paths(resnets)
        meta_path = {
            'old': f'mid.block_{i}',
            'new': f'mid_block.resnets.{i - 1}',
        }
        assign_to_checkpoint(
            paths,
            new_checkpoint,
            vae_state_dict,
            additional_replacements=[meta_path],
            config=config,
        )
    # Decoder mid attention.
    mid_attentions = [
        key for key in vae_state_dict if 'decoder.mid.attn' in key
    ]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {'old': 'mid.attn_1', 'new': 'mid_block.attentions.0'}
    assign_to_checkpoint(
        paths,
        new_checkpoint,
        vae_state_dict,
        additional_replacements=[meta_path],
        config=config,
    )
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
def create_unet_diffusers_config(v2):
    """Build the UNet config dict from the module-level ``UNET_PARAMS_*`` values.

    When ``v2`` is truthy, ``cross_attention_dim`` and ``attention_head_dim``
    are taken from the ``V2_UNET_PARAMS_*`` constants instead of the v1 ones.
    Returns a plain ``dict`` of keyword arguments.
    """
    block_out_channels = [
        UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
    ]
    num_blocks = len(block_out_channels)

    # Walk down: the resolution factor doubles after every block but the last.
    resolution = 1
    down_block_types = []
    for index in range(num_blocks):
        if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS:
            down_block_types.append('CrossAttnDownBlock2D')
        else:
            down_block_types.append('DownBlock2D')
        if index != num_blocks - 1:
            resolution *= 2

    # Walk back up, continuing from the resolution reached above.
    up_block_types = []
    for _ in range(num_blocks):
        if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS:
            up_block_types.append('CrossAttnUpBlock2D')
        else:
            up_block_types.append('UpBlock2D')
        resolution //= 2

    if not v2:
        cross_attention_dim = UNET_PARAMS_CONTEXT_DIM
        attention_head_dim = UNET_PARAMS_NUM_HEADS
    else:
        cross_attention_dim = V2_UNET_PARAMS_CONTEXT_DIM
        attention_head_dim = V2_UNET_PARAMS_ATTENTION_HEAD_DIM

    return {
        'sample_size': UNET_PARAMS_IMAGE_SIZE,
        'in_channels': UNET_PARAMS_IN_CHANNELS,
        'out_channels': UNET_PARAMS_OUT_CHANNELS,
        'down_block_types': tuple(down_block_types),
        'up_block_types': tuple(up_block_types),
        'block_out_channels': tuple(block_out_channels),
        'layers_per_block': UNET_PARAMS_NUM_RES_BLOCKS,
        'cross_attention_dim': cross_attention_dim,
        'attention_head_dim': attention_head_dim,
    }
def create_vae_diffusers_config():
    """Build the VAE config dict from the module-level ``VAE_PARAMS_*`` values.

    Every block uses ``DownEncoderBlock2D`` on the way down and
    ``UpDecoderBlock2D`` on the way up, one per entry of
    ``VAE_PARAMS_CH_MULT``. Returns a plain ``dict`` of keyword arguments.
    """
    block_out_channels = tuple(
        VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT
    )
    num_blocks = len(block_out_channels)
    return {
        'sample_size': VAE_PARAMS_RESOLUTION,
        'in_channels': VAE_PARAMS_IN_CHANNELS,
        'out_channels': VAE_PARAMS_OUT_CH,
        'down_block_types': ('DownEncoderBlock2D',) * num_blocks,
        'up_block_types': ('UpDecoderBlock2D',) * num_blocks,
        'block_out_channels': block_out_channels,
        'latent_channels': VAE_PARAMS_Z_CHANNELS,
        'layers_per_block': VAE_PARAMS_NUM_RES_BLOCKS,
    }
def convert_ldm_clip_checkpoint_v1(checkpoint):
    """Extract the v1 text-encoder weights from ``checkpoint``.

    Keys starting with ``'cond_stage_model.transformer'`` are kept, with the
    first ``len('cond_stage_model.transformer.')`` characters removed.
    Returns a new dict; ``checkpoint`` is not modified.
    """
    marker = 'cond_stage_model.transformer'
    strip_length = len('cond_stage_model.transformer.')
    return {
        name[strip_length:]: checkpoint[name]
        for name in list(checkpoint.keys())
        if name.startswith(marker)
    }
def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
    """Convert the v2 text-encoder weights in ``checkpoint`` to new key names.

    Keys containing ``'.resblocks.23.'`` are skipped. Fused
    ``attn.in_proj_*`` tensors are split into three chunks and stored as
    ``q_proj`` / ``k_proj`` / ``v_proj``. ``text_projection`` and
    ``logit_scale`` entries are dropped. A
    ``text_model.embeddings.position_ids`` tensor of ``0..max_length-1``
    (int64, one leading dimension) is added.

    Raises:
        ValueError: for a ``resblocks`` key matching none of the known
            sub-layer patterns.
    """

    # The naming differs from v1 in almost every respect.
    def convert_key(key):
        if not key.startswith('cond_stage_model'):
            return None
        # common conversion
        key = key.replace(
            'cond_stage_model.model.transformer.', 'text_model.encoder.'
        )
        key = key.replace('cond_stage_model.model.', 'text_model.')
        if 'resblocks' in key:
            # resblocks conversion
            key = key.replace('.resblocks.', '.layers.')
            if '.ln_' in key:
                return key.replace('.ln_', '.layer_norm')
            if '.mlp.' in key:
                key = key.replace('.c_fc.', '.fc1.')
                return key.replace('.c_proj.', '.fc2.')
            if '.attn.out_proj' in key:
                return key.replace('.attn.out_proj.', '.self_attn.out_proj.')
            if '.attn.in_proj' in key:
                # special case, handled in the second pass below
                return None
            raise ValueError(f'unexpected key in SD: {key}')
        if '.positional_embedding' in key:
            return key.replace(
                '.positional_embedding',
                '.embeddings.position_embedding.weight',
            )
        if '.text_projection' in key or '.logit_scale' in key:
            # apparently unused
            return None
        if '.token_embedding' in key:
            return key.replace(
                '.token_embedding.weight', '.embeddings.token_embedding.weight'
            )
        if '.ln_final' in key:
            return key.replace('.ln_final', '.final_layer_norm')
        return key

    # resblocks 23 is removed from both passes
    source_keys = [
        name for name in checkpoint.keys() if '.resblocks.23.' not in name
    ]
    new_sd = {}
    for name in source_keys:
        converted = convert_key(name)
        if converted is not None:
            new_sd[converted] = checkpoint[name]

    # attention conversion: split the fused tensor into three
    for name in source_keys:
        if not ('.resblocks' in name and '.attn.in_proj_' in name):
            continue
        chunks = torch.chunk(checkpoint[name], 3)
        suffix = '.weight' if 'weight' in name else '.bias'
        prefix = name.replace(
            'cond_stage_model.model.transformer.resblocks.',
            'text_model.encoder.layers.',
        )
        prefix = prefix.replace('_weight', '')
        prefix = prefix.replace('_bias', '')
        prefix = prefix.replace('.attn.in_proj', '.self_attn.')
        new_sd[prefix + 'q_proj' + suffix] = chunks[0]
        new_sd[prefix + 'k_proj' + suffix] = chunks[1]
        new_sd[prefix + 'v_proj' + suffix] = chunks[2]

    # add position_ids
    position_ids = torch.Tensor([list(range(max_length))])
    new_sd['text_model.embeddings.position_ids'] = position_ids.to(torch.int64)
    return new_sd
# endregion
# region Conversion code: Diffusers -> StableDiffusion
# Copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
def conv_transformer_to_linear(checkpoint):
    """Squeeze conv-shaped ``proj_in`` / ``proj_out`` weights in place.

    For every key whose last two dot-segments are ``proj_in.weight`` or
    ``proj_out.weight`` and whose tensor has more than two dimensions, the
    tensor is replaced by its ``[:, :, 0, 0]`` slice.
    """
    transformer_suffixes = ['proj_in.weight', 'proj_out.weight']
    for name in list(checkpoint.keys()):
        if '.'.join(name.split('.')[-2:]) not in transformer_suffixes:
            continue
        weight = checkpoint[name]
        if weight.ndim > 2:
            checkpoint[name] = weight[:, :, 0, 0]
def convert_unet_state_dict_to_sd(v2, unet_state_dict):
    """Rename the keys of a Diffusers-style UNet state dict to SD-style keys.

    Args:
        v2: when truthy, ``conv_transformer_to_linear`` is applied to the
            result before returning.
        unet_state_dict: dict keyed by HF Diffusers names; it is only read.

    Returns:
        A new dict with the same values under the converted key names.

    Raises:
        KeyError: every HF name in ``unet_conversion_map`` is added to the
            mapping unconditionally, so a ``unet_state_dict`` lacking one of
            them fails when the result dict is built.
    """
    # Exact full-key renames.
    unet_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ('time_embed.0.weight', 'time_embedding.linear_1.weight'),
        ('time_embed.0.bias', 'time_embedding.linear_1.bias'),
        ('time_embed.2.weight', 'time_embedding.linear_2.weight'),
        ('time_embed.2.bias', 'time_embedding.linear_2.bias'),
        ('input_blocks.0.0.weight', 'conv_in.weight'),
        ('input_blocks.0.0.bias', 'conv_in.bias'),
        ('out.0.weight', 'conv_norm_out.weight'),
        ('out.0.bias', 'conv_norm_out.bias'),
        ('out.2.weight', 'conv_out.weight'),
        ('out.2.bias', 'conv_out.bias'),
    ]
    # Substring renames applied only to keys containing 'resnets'.
    unet_conversion_map_resnet = [
        # (stable-diffusion, HF Diffusers)
        ('in_layers.0', 'norm1'),
        ('in_layers.2', 'conv1'),
        ('out_layers.0', 'norm2'),
        ('out_layers.3', 'conv2'),
        ('emb_layers.1', 'time_emb_proj'),
        ('skip_connection', 'conv_shortcut'),
    ]
    # Prefix renames for block layers; the counts (4 blocks, 2 down / 3 up
    # layers per block) are hard-coded here.
    unet_conversion_map_layer = []
    for i in range(4):
        # loop over downblocks/upblocks
        for j in range(2):
            # loop over resnets/attentions for downblocks
            hf_down_res_prefix = f'down_blocks.{i}.resnets.{j}.'
            sd_down_res_prefix = f'input_blocks.{3*i + j + 1}.0.'
            unet_conversion_map_layer.append(
                (sd_down_res_prefix, hf_down_res_prefix)
            )
            if i < 3:
                # no attention layers in down_blocks.3
                hf_down_atn_prefix = f'down_blocks.{i}.attentions.{j}.'
                sd_down_atn_prefix = f'input_blocks.{3*i + j + 1}.1.'
                unet_conversion_map_layer.append(
                    (sd_down_atn_prefix, hf_down_atn_prefix)
                )
        for j in range(3):
            # loop over resnets/attentions for upblocks
            hf_up_res_prefix = f'up_blocks.{i}.resnets.{j}.'
            sd_up_res_prefix = f'output_blocks.{3*i + j}.0.'
            unet_conversion_map_layer.append(
                (sd_up_res_prefix, hf_up_res_prefix)
            )
            if i > 0:
                # no attention layers in up_blocks.0
                hf_up_atn_prefix = f'up_blocks.{i}.attentions.{j}.'
                sd_up_atn_prefix = f'output_blocks.{3*i + j}.1.'
                unet_conversion_map_layer.append(
                    (sd_up_atn_prefix, hf_up_atn_prefix)
                )
        if i < 3:
            # no downsample in down_blocks.3
            hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.conv.'
            sd_downsample_prefix = f'input_blocks.{3*(i+1)}.0.op.'
            unet_conversion_map_layer.append(
                (sd_downsample_prefix, hf_downsample_prefix)
            )
            # no upsample in up_blocks.3
            hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.'
            sd_upsample_prefix = (
                f'output_blocks.{3*i + 2}.{1 if i == 0 else 2}.'
            )
            unet_conversion_map_layer.append(
                (sd_upsample_prefix, hf_upsample_prefix)
            )
    hf_mid_atn_prefix = 'mid_block.attentions.0.'
    sd_mid_atn_prefix = 'middle_block.1.'
    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
    for j in range(2):
        hf_mid_res_prefix = f'mid_block.resnets.{j}.'
        sd_mid_res_prefix = f'middle_block.{2*j}.'
        unet_conversion_map_layer.append(
            (sd_mid_res_prefix, hf_mid_res_prefix)
        )
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    # mapping: HF key -> SD key, starting as the identity.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    # The loops below only reassign existing keys while iterating, so the
    # dict size does not change during iteration.
    for k, v in mapping.items():
        if 'resnets' in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    if v2:
        conv_transformer_to_linear(new_state_dict)
    return new_state_dict
# ================#
# VAE Conversion #
# ================#
def reshape_weight_for_sd(w):
    """Return ``w`` reshaped with two trailing unit dimensions appended.

    Used to turn HF linear weights into SD conv2d-shaped weights.
    """
    target_shape = (*w.shape, 1, 1)
    return w.reshape(*target_shape)
def convert_vae_state_dict(vae_state_dict):
    """Rename the keys of a Diffusers-style VAE state dict to SD-style keys.

    ``vae_state_dict`` is only read. Returns a new dict with the same values
    under the converted names; the ``q`` / ``k`` / ``v`` / ``proj_out``
    weights under ``mid.attn_1`` are passed through ``reshape_weight_for_sd``.
    The block counts (4 blocks, 2 down / 3 up resnets per block) are
    hard-coded here.
    """
    # Substring renames applied to every key, in list order.
    vae_conversion_map = [
        # (stable-diffusion, HF Diffusers)
        ('nin_shortcut', 'conv_shortcut'),
        ('norm_out', 'conv_norm_out'),
        ('mid.attn_1.', 'mid_block.attentions.0.'),
    ]
    for i in range(4):
        # down_blocks have two resnets
        for j in range(2):
            hf_down_prefix = f'encoder.down_blocks.{i}.resnets.{j}.'
            sd_down_prefix = f'encoder.down.{i}.block.{j}.'
            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
        if i < 3:
            hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.'
            sd_downsample_prefix = f'down.{i}.downsample.'
            vae_conversion_map.append(
                (sd_downsample_prefix, hf_downsample_prefix)
            )
            hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.'
            sd_upsample_prefix = f'up.{3-i}.upsample.'
            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
        # up_blocks have three resnets
        # also, up blocks in hf are numbered in reverse from sd
        for j in range(3):
            hf_up_prefix = f'decoder.up_blocks.{i}.resnets.{j}.'
            sd_up_prefix = f'decoder.up.{3-i}.block.{j}.'
            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
    # this part accounts for mid blocks in both the encoder and the decoder
    for i in range(2):
        hf_mid_res_prefix = f'mid_block.resnets.{i}.'
        sd_mid_res_prefix = f'mid.block_{i+1}.'
        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
    # Substring renames applied only to keys containing 'attentions'.
    vae_conversion_map_attn = [
        # (stable-diffusion, HF Diffusers)
        ('norm.', 'group_norm.'),
        ('q.', 'query.'),
        ('k.', 'key.'),
        ('v.', 'value.'),
        ('proj_out.', 'proj_attn.'),
    ]
    # mapping: HF key -> SD key, starting as the identity. The loops below
    # only reassign existing keys, so the dict size is stable while iterating.
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if 'attentions' in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    # Mid-attention weights get two trailing unit dims appended.
    weights_to_convert = ['q', 'k', 'v', 'proj_out']
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f'mid.attn_1.{weight_name}.weight' in k:
                # print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict
# endregion
# region Custom model loading/saving utilities
def is_safetensors(path):
    """Return True when ``path`` has a ``.safetensors`` extension (any case)."""
    _, extension = os.path.splitext(path)
    return extension.lower() == '.safetensors'
def load_checkpoint_with_text_encoder_conversion(ckpt_path):
    """Load a checkpoint file and normalise its text-encoder key prefixes.

    Supports models whose text encoder is stored in a different layout
    (without the ``text_model`` segment): such keys are renamed in place to
    include it.

    Returns:
        ``(checkpoint, state_dict)``. ``checkpoint`` is the loaded object
        only when it was loaded with ``torch.load`` and contains a
        ``'state_dict'`` entry; otherwise it is ``None``.
    """
    prefix_upgrades = [
        (
            'cond_stage_model.transformer.embeddings.',
            'cond_stage_model.transformer.text_model.embeddings.',
        ),
        (
            'cond_stage_model.transformer.encoder.',
            'cond_stage_model.transformer.text_model.encoder.',
        ),
        (
            'cond_stage_model.transformer.final_layer_norm.',
            'cond_stage_model.transformer.text_model.final_layer_norm.',
        ),
    ]

    if is_safetensors(ckpt_path):
        checkpoint = None
        state_dict = load_file(ckpt_path, 'cpu')
    else:
        # NOTE(review): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        loaded = torch.load(ckpt_path, map_location='cpu')
        if 'state_dict' in loaded:
            checkpoint = loaded
            state_dict = loaded['state_dict']
        else:
            checkpoint = None
            state_dict = loaded

    # Collect all renames first so the dict is not resized while iterating.
    renames = [
        (key, new_prefix + key[len(old_prefix) :])
        for old_prefix, new_prefix in prefix_upgrades
        for key in state_dict.keys()
        if key.startswith(old_prefix)
    ]
    for old_key, new_key in renames:
        state_dict[new_key] = state_dict[old_key]
        del state_dict[old_key]
    return checkpoint, state_dict
# TODO: the dtype handling looks questionable and should be checked; whether
# the text encoder can be created in the requested dtype is unverified.
def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
    """Load a Stable Diffusion checkpoint and build the three sub-models.

    Args:
        v2: selects the v2 UNet config and the v2 text-encoder conversion.
        ckpt_path: path handed to
            ``load_checkpoint_with_text_encoder_conversion``.
        dtype: when not ``None``, every value whose type is exactly
            ``torch.Tensor`` is converted with ``.to(dtype)`` first.

    Returns:
        ``(text_model, vae, unet)``.
    """
    _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
    if dtype is not None:
        for name, value in state_dict.items():
            if type(value) is torch.Tensor:
                state_dict[name] = value.to(dtype)

    # Convert the UNet2DConditionModel model.
    unet_config = create_unet_diffusers_config(v2)
    unet_weights = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
    unet = UNet2DConditionModel(**unet_config)
    info = unet.load_state_dict(unet_weights)
    print('loading u-net:', info)

    # Convert the VAE model.
    vae_config = create_vae_diffusers_config()
    vae_weights = convert_ldm_vae_checkpoint(state_dict, vae_config)
    vae = AutoencoderKL(**vae_config)
    info = vae.load_state_dict(vae_weights)
    print('loadint vae:', info)

    # convert text_model
    if v2:
        text_weights = convert_ldm_clip_checkpoint_v2(state_dict, 77)
        cfg = CLIPTextConfig(
            vocab_size=49408,
            hidden_size=1024,
            intermediate_size=4096,
            num_hidden_layers=23,
            num_attention_heads=16,
            max_position_embeddings=77,
            hidden_act='gelu',
            layer_norm_eps=1e-05,
            dropout=0.0,
            attention_dropout=0.0,
            initializer_range=0.02,
            initializer_factor=1.0,
            pad_token_id=1,
            bos_token_id=0,
            eos_token_id=2,
            model_type='clip_text_model',
            projection_dim=512,
            torch_dtype='float32',
            transformers_version='4.25.0.dev0',
        )
        text_model = CLIPTextModel._from_config(cfg)
    else:
        text_weights = convert_ldm_clip_checkpoint_v1(state_dict)
        text_model = CLIPTextModel.from_pretrained(
            'openai/clip-vit-large-patch14'
        )
    info = text_model.load_state_dict(text_weights)
    print('loading text encoder:', info)
    return text_model, vae, unet
def convert_text_encoder_state_dict_to_sd_v2(
    checkpoint, make_dummy_weights=False
):
    """Rename a Diffusers text-encoder state dict to the SD v2 key layout.

    Args:
        checkpoint: Mapping of Diffusers-style key -> tensor.
        make_dummy_weights: When true, also fabricate
            ``transformer.resblocks.23.*`` (cloned from block 22),
            ``text_projection`` and ``logit_scale`` entries.

    Returns:
        A new dict with renamed keys. Per-layer q/k/v projections are
        concatenated (in that order) into a single ``attn.in_proj_*`` entry.

    Raises:
        ValueError: If a ``layers`` key matches none of the known sub-module
            patterns.
    """

    def rename(name):
        # Returns the new key, or None when the entry must not be copied
        # one-to-one (position_ids, and q/k/v which are fused below).
        if '.position_ids' in name:
            return None

        # Prefix handling shared by every remaining key.
        name = name.replace('text_model.encoder.', 'transformer.').replace(
            'text_model.', ''
        )

        if 'layers' in name:
            # Transformer layers become "resblocks".
            name = name.replace('.layers.', '.resblocks.')
            if '.layer_norm' in name:
                return name.replace('.layer_norm', '.ln_')
            if '.mlp.' in name:
                return name.replace('.fc1.', '.c_fc.').replace(
                    '.fc2.', '.c_proj.'
                )
            if '.self_attn.out_proj' in name:
                return name.replace('.self_attn.out_proj.', '.attn.out_proj.')
            if '.self_attn.' in name:
                return None  # special case, handled in the fuse pass
            raise ValueError(f'unexpected key in DiffUsers model: {name}')

        if '.position_embedding' in name:
            return name.replace(
                'embeddings.position_embedding.weight', 'positional_embedding'
            )
        if '.token_embedding' in name:
            return name.replace(
                'embeddings.token_embedding.weight', 'token_embedding.weight'
            )
        if 'final_layer_norm' in name:
            return name.replace('final_layer_norm', 'ln_final')
        return name

    source_keys = list(checkpoint.keys())

    # Pass 1: straight renames.
    converted = {}
    for name in source_keys:
        target = rename(name)
        if target is not None:
            converted[target] = checkpoint[name]

    # Pass 2: fuse q/k/v of each attention layer into one in_proj tensor.
    for q_name in source_keys:
        if 'layers' not in q_name or 'q_proj' not in q_name:
            continue
        fused = torch.cat(
            [
                checkpoint[q_name.replace('q_proj', part)]
                for part in ('q_proj', 'k_proj', 'v_proj')
            ]
        )
        target = q_name.replace(
            'text_model.encoder.layers.', 'transformer.resblocks.'
        ).replace('.self_attn.q_proj.', '.attn.in_proj_')
        converted[target] = fused

    # Optionally fabricate the final layer and the other missing weights.
    if make_dummy_weights:
        print(
            'make dummy weights for resblock.23, text_projection and logit scale.'
        )
        existing = list(converted.keys())
        for name in existing:
            if name.startswith('transformer.resblocks.22.'):
                # Must be a real copy, otherwise saving with safetensors fails.
                converted[name.replace('.22.', '.23.')] = converted[
                    name
                ].clone()

        # Create the weights that the Diffusers model does not contain,
        # matching dtype/device of the first converted entry.
        reference = converted[existing[0]]
        converted['text_projection'] = torch.ones(
            (1024, 1024),
            dtype=reference.dtype,
            device=reference.device,
        )
        converted['logit_scale'] = torch.tensor(1)

    return converted
def save_stable_diffusion_checkpoint(
    v2,
    output_file,
    text_encoder,
    unet,
    ckpt_path,
    epochs,
    steps,
    save_dtype=None,
    vae=None,
):
    """Write the models out as a Stable Diffusion style checkpoint.

    Args:
        v2: Truthy to use the SD v2 UNet / text-encoder key conversion.
        output_file: Destination path; saved with ``save_file`` when
            ``is_safetensors(output_file)`` is true, else with ``torch.save``.
        text_encoder: Model whose ``state_dict()`` supplies the text encoder.
        unet: Model whose ``state_dict()`` supplies the UNet.
        ckpt_path: Optional base checkpoint whose state dict is updated and
            whose ``epoch`` / ``global_step`` (if present) are added to
            ``epochs`` / ``steps``. When None, a fresh state dict is built
            and ``vae`` is required.
        epochs: Epoch count to record.
        steps: Step count to record.
        save_dtype: When not None, every written tensor is detached, cloned,
            moved to CPU and cast to this dtype.
        vae: Optional VAE model to write as well.

    Returns:
        Number of keys in the final state dict.
    """
    if ckpt_path is None:
        # No base checkpoint: build everything from scratch.
        assert (
            vae is not None
        ), 'VAE is required to save a checkpoint without a given checkpoint'
        checkpoint, state_dict, strict = {}, {}, False
    else:
        # Re-read the base checkpoint for its epoch/step, and so that its VAE
        # weights are available even when no VAE is held in memory.
        checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(
            ckpt_path
        )
        # A missing wrapper dict means safetensors or a state-dict-only ckpt;
        # key checking is only enforced for full checkpoints.
        strict = checkpoint is not None
        if checkpoint is None:
            checkpoint = {}
        state_dict.pop('state_dict', None)

    def update_sd(prefix, sd):
        # Copy `sd` into the output state dict under `prefix`.
        for name, tensor in sd.items():
            full_name = prefix + name
            assert (
                not strict or full_name in state_dict
            ), f'Illegal key in save SD: {full_name}'
            state_dict[full_name] = (
                tensor
                if save_dtype is None
                else tensor.detach().clone().to('cpu').to(save_dtype)
            )

    # UNet
    update_sd(
        'model.diffusion_model.',
        convert_unet_state_dict_to_sd(v2, unet.state_dict()),
    )

    # Text encoder
    if v2:
        # Without a reference checkpoint the last layer etc. are fabricated
        # (copied from the preceding layer) as dummy weights.
        make_dummy = ckpt_path is None
        update_sd(
            'cond_stage_model.model.',
            convert_text_encoder_state_dict_to_sd_v2(
                text_encoder.state_dict(), make_dummy
            ),
        )
    else:
        update_sd('cond_stage_model.transformer.', text_encoder.state_dict())

    # VAE
    if vae is not None:
        update_sd('first_stage_model.', convert_vae_state_dict(vae.state_dict()))

    key_count = len(state_dict.keys())

    # Accumulate training progress on top of the base checkpoint's counters.
    if 'epoch' in checkpoint:
        epochs += checkpoint['epoch']
    if 'global_step' in checkpoint:
        steps += checkpoint['global_step']
    new_ckpt = {
        'state_dict': state_dict,
        'epoch': epochs,
        'global_step': steps,
    }

    if is_safetensors(output_file):
        # TODO: should non-Tensor values in the dict be removed first?
        save_file(state_dict, output_file)
    else:
        torch.save(new_ckpt, output_file)

    return key_count
def save_diffusers_checkpoint(
    v2,
    output_dir,
    text_encoder,
    unet,
    pretrained_model_name_or_path,
    vae=None,
    use_safetensors=False,
):
    """Assemble a ``StableDiffusionPipeline`` and save it to ``output_dir``.

    Args:
        v2: Selects which default reference model id is used when
            ``pretrained_model_name_or_path`` is None.
        output_dir: Directory passed to ``pipeline.save_pretrained``.
        text_encoder: Text encoder placed in the pipeline.
        unet: UNet placed in the pipeline.
        pretrained_model_name_or_path: Source of the scheduler, tokenizer
            and (if ``vae`` is None) the VAE; None falls back to the v1/v2
            reference model id.
        vae: Optional VAE; loaded from the reference model when None.
        use_safetensors: Forwarded as ``safe_serialization``.
    """
    reference = pretrained_model_name_or_path
    if reference is None:
        # Fall back to the stock settings for v1/v2.
        reference = (
            DIFFUSERS_REF_MODEL_ID_V2 if v2 else DIFFUSERS_REF_MODEL_ID_V1
        )

    scheduler = DDIMScheduler.from_pretrained(reference, subfolder='scheduler')
    tokenizer = CLIPTokenizer.from_pretrained(reference, subfolder='tokenizer')
    if vae is None:
        vae = AutoencoderKL.from_pretrained(reference, subfolder='vae')

    components = {
        'unet': unet,
        'text_encoder': text_encoder,
        'vae': vae,
        'scheduler': scheduler,
        'tokenizer': tokenizer,
        'safety_checker': None,
        'feature_extractor': None,
        'requires_safety_checker': None,
    }
    StableDiffusionPipeline(**components).save_pretrained(
        output_dir, safe_serialization=use_safetensors
    )
# Key prefix that marks VAE weights inside a full Stable Diffusion state dict;
# load_vae() uses it to tell a full model apart from a VAE-only file.
VAE_PREFIX = 'first_stage_model.'
def load_vae(vae_id, dtype):
    """Load a VAE from a Diffusers model or from a local weights file.

    Args:
        vae_id: Either a Diffusers model directory / remote id (anything
            that is a directory or is not an existing file), or a path to
            a local weights file: a Diffusers ``.bin`` state dict, a
            ``.safetensors`` file, or a torch-pickled Stable Diffusion
            checkpoint (full model or VAE only).
        dtype: Forwarded as ``torch_dtype`` on the Diffusers path.
            NOTE(review): it is not applied on the local-file path.

    Returns:
        The loaded ``AutoencoderKL``.
    """
    print(f'load VAE: {vae_id}')
    if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
        # Diffusers local/remote
        try:
            vae = AutoencoderKL.from_pretrained(
                vae_id, subfolder=None, torch_dtype=dtype
            )
        except EnvironmentError as e:
            print(f'exception occurs in loading vae: {e}')
            print("retry with subfolder='vae'")
            vae = AutoencoderKL.from_pretrained(
                vae_id, subfolder='vae', torch_dtype=dtype
            )
        return vae

    # Local file.
    # NOTE: torch.load unpickles the file; only load weights from trusted
    # sources.
    vae_config = create_vae_diffusers_config()

    if vae_id.endswith('.bin'):
        # SD 1.5 VAE on Huggingface: already in Diffusers key layout.
        converted_vae_checkpoint = torch.load(vae_id, map_location='cpu')
    else:
        # Stable Diffusion layout, stored as safetensors or as a pickle.
        if os.path.splitext(vae_id)[1].lower() == '.safetensors':
            vae_model = load_file(vae_id, 'cpu')
        else:
            vae_model = torch.load(vae_id, map_location='cpu')

        # Some files wrap the weights in 'state_dict'; safetensors files and
        # bare state dicts keep them at the top level.
        vae_sd = vae_model['state_dict'] if 'state_dict' in vae_model else vae_model

        # A VAE-only file lacks the prefix the converter expects, so add it.
        is_full_model = any(key.startswith(VAE_PREFIX) for key in vae_sd)
        if not is_full_model:
            vae_sd = {VAE_PREFIX + key: value for key, value in vae_sd.items()}

        # Convert the VAE model.
        converted_vae_checkpoint = convert_ldm_vae_checkpoint(
            vae_sd, vae_config
        )

    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(converted_vae_checkpoint)
    return vae
def get_epoch_ckpt_name(use_safetensors, epoch):
    """Return the per-epoch checkpoint file name, e.g. 'epoch-000003.ckpt'.

    The epoch is zero-padded to six digits; the extension is '.safetensors'
    when ``use_safetensors`` is truthy, otherwise '.ckpt'.
    """
    if use_safetensors:
        extension = '.safetensors'
    else:
        extension = '.ckpt'
    return f'epoch-{epoch:06d}{extension}'
def get_last_ckpt_name(use_safetensors):
    """Return the file name of the final checkpoint.

    'last.safetensors' when ``use_safetensors`` is truthy, else 'last.ckpt'.
    """
    if use_safetensors:
        return 'last.safetensors'
    return 'last.ckpt'
# endregion
def make_bucket_resolutions(
    max_reso, min_size=256, max_size=1024, divisible=64
):
    """Enumerate bucket resolutions that fit within an area budget.

    Args:
        max_reso: ``(max_width, max_height)``; their product, measured in
            ``divisible`` x ``divisible`` cells, is the area budget.
        min_size: Smallest width tried.
        max_size: Largest width tried, and the cap applied to heights.
        divisible: Step between widths and the cell size for the area.

    Returns:
        ``(resos, aspect_ratios)``: the sorted list of unique
        ``(width, height)`` tuples (each pair is added in both orientations,
        plus one square bucket), and the matching list of ``width / height``.
    """
    max_width, max_height = max_reso
    max_area = (max_width // divisible) * (max_height // divisible)

    # Start with the largest square that fits in the area budget.
    square_side = int(math.sqrt(max_area)) * divisible
    buckets = {(square_side, square_side)}

    # For every width, take the tallest height the budget allows (capped at
    # max_size) and register both orientations.
    for width in range(min_size, max_size + 1, divisible):
        height = min(max_size, (max_area // (width // divisible)) * divisible)
        buckets.update(((width, height), (height, width)))

    resos = sorted(buckets)
    aspect_ratios = [w / h for w, h in resos]
    return resos, aspect_ratios
if __name__ == '__main__':
    # Manual check: print the buckets for a 512x768 budget and report any
    # aspect ratio that occurs more than once.
    buckets, ratios = make_bucket_resolutions((512, 768))
    for item in (len(buckets), buckets, ratios):
        print(item)

    seen_ratios = set()
    for ratio in ratios:
        if ratio in seen_ratios:
            print('error! duplicate ar:', ratio)
        seen_ratios.add(ratio)