File size: 7,692 Bytes
2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d 2e50127 828547d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import torch
from safetensors.torch import safe_open
from modules import scripts, sd_models, shared
import gradio as gr
from modules.processing import process_images
class KeyBasedModelMerger(scripts.Script):
    """Merge selected tensors from two or three safetensors checkpoints into the
    currently loaded model (shared.sd_model), matching tensor keys by substring
    and blending each matched tensor with a per-key alpha ratio."""

    def title(self):
        """Name shown in the script dropdown of the web UI."""
        return "Key-based model merging"

    def ui(self, is_txt2img):
        """Build the Gradio controls.

        Returns the components in the exact order run() receives them.
        """
        model_names = sorted(sd_models.checkpoints_list.keys(), key=str.casefold)
        default_model = model_names[0] if model_names else None
        model_a_dropdown = gr.Dropdown(
            label="Model A", choices=model_names, value=default_model
        )
        model_b_dropdown = gr.Dropdown(
            label="Model B", choices=model_names, value=default_model
        )
        model_c_dropdown = gr.Dropdown(
            label="Model C (Add difference mode用)", choices=model_names, value=default_model
        )
        # One "key_substring,alpha" pair per line; keys match by substring.
        keys_and_alphas_textbox = gr.Textbox(
            label="マージするテンソルのキーとマージ比率 (部分一致, 1行に1つ, カンマ区切り)",
            lines=5,
            placeholder="例:\nmodel.diffusion_model.input_blocks.0,0.5\nmodel.diffusion_model.middle_block,0.3"
        )
        merge_checkbox = gr.Checkbox(label="モデルのマージを有効にする", value=True)
        use_gpu_checkbox = gr.Checkbox(label="GPUを使用", value=True)
        batch_size_slider = gr.Slider(minimum=1, maximum=500, step=1, value=250, label="KeyMgerge_BatchSize")
        merge_mode_dropdown = gr.Dropdown(
            label="Merge Mode",
            choices=["Normal", "Add difference (B-C to Current)", "Add difference (A + (B-C) to Current)"],
            value="Normal"
        )
        return [model_a_dropdown, model_b_dropdown, model_c_dropdown, keys_and_alphas_textbox,
                merge_checkbox, use_gpu_checkbox, batch_size_slider, merge_mode_dropdown]

    @staticmethod
    def _parse_keys_and_alphas(keys_and_alphas_str):
        """Parse "key_substring,alpha" lines into a list of (key_substring, alpha).

        Lines without a comma are silently ignored; lines whose alpha is not a
        valid float are skipped with a warning.
        """
        parsed = []
        for line in keys_and_alphas_str.split("\n"):
            if "," not in line:
                continue
            key_part, alpha_str = line.split(",", 1)
            try:
                alpha = float(alpha_str)
            except ValueError:
                # BUG FIX: the original f-string dropped the offending line.
                print(f"Invalid alpha value in line '{line}', skipping...")
                continue
            parsed.append((key_part, alpha))
        return parsed

    def run(self, p, model_a_name, model_b_name, model_c_name, keys_and_alphas_str,
            merge_enabled, use_gpu, batch_size, merge_mode):
        """Optionally merge the requested tensors into the live model, then generate.

        NOTE(review): the merge mutates shared.sd_model in place and is not
        reverted after generation — reload the checkpoint to undo it.
        """
        if not model_b_name:
            print("Error: Model B is not selected.")
            return p
        try:
            # Resolve only the checkpoint files the selected mode actually reads.
            if merge_mode == "Normal":
                model_a_filename = sd_models.checkpoints_list[model_a_name].filename
                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
            elif merge_mode == "Add difference (B-C to Current)":
                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
                model_c_filename = sd_models.checkpoints_list[model_c_name].filename
            elif merge_mode == "Add difference (A + (B-C) to Current)":
                model_a_filename = sd_models.checkpoints_list[model_a_name].filename
                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
                model_c_filename = sd_models.checkpoints_list[model_c_name].filename
            else:
                # BUG FIX: the original f-string dropped the mode value.
                raise ValueError(f"Invalid merge mode: {merge_mode}")
        except KeyError as e:
            # BUG FIX: the original f-string dropped the missing model name.
            print(f"Error: Selected model is not found in checkpoints list. {e}")
            return p
        if merge_enabled:
            input_keys_and_alphas = self._parse_keys_and_alphas(keys_and_alphas_str)
            # Expand each substring against the live model's keys (partial match).
            # A later input line overrides the alpha of an earlier match.
            model_keys = list(shared.sd_model.state_dict().keys())
            final_keys_and_alphas = {}
            for key_part, alpha in input_keys_and_alphas:
                for model_key in model_keys:
                    if key_part in model_key:
                        final_keys_and_alphas[model_key] = alpha
            # Checkpoint tensors are loaded directly onto the chosen device.
            device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
            batched_keys = list(final_keys_and_alphas.items())
            # Open only the files the mode needs; safe_open reads tensors lazily.
            if merge_mode == "Normal":
                with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
                     safe_open(model_b_filename, framework="pt", device=device) as f_b:
                    self._merge_models(f_a, f_b, None, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
            elif merge_mode == "Add difference (B-C to Current)":
                with safe_open(model_b_filename, framework="pt", device=device) as f_b, \
                     safe_open(model_c_filename, framework="pt", device=device) as f_c:
                    self._merge_models(None, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
            elif merge_mode == "Add difference (A + (B-C) to Current)":
                with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
                     safe_open(model_b_filename, framework="pt", device=device) as f_b, \
                     safe_open(model_c_filename, framework="pt", device=device) as f_c:
                    self._merge_models(f_a, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
            # (Invalid modes already raised in the validation block above.)
        return process_images(p)

    def _merge_models(self, f_a, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device):
        """Merge the given keys batch-by-batch into the live model in place.

        f_a/f_b/f_c are open safetensors handles, or None when the mode does
        not use that model. Results are written with Tensor.copy_ so the model
        keeps its original parameter storage, dtype and device.
        """
        # PERF FIX: state_dict() rebuilds its mapping on every call; the
        # original called it once (or twice) per key. The returned dict
        # references the live tensors, so a single fetch is equivalent.
        current_state = shared.sd_model.state_dict()
        for i in range(0, len(batched_keys), batch_size):
            batch = batched_keys[i:i + batch_size]
            # Fetch the whole batch of tensors up front from each open file.
            tensors_a = [f_a.get_tensor(key) for key, _ in batch] if f_a is not None else None
            tensors_b = [f_b.get_tensor(key) for key, _ in batch] if f_b is not None else None
            tensors_c = [f_c.get_tensor(key) for key, _ in batch] if f_c is not None else None
            for j, (key, alpha) in enumerate(batch):
                tensor_a = tensors_a[j] if tensors_a is not None else None
                tensor_b = tensors_b[j] if tensors_b is not None else None
                tensor_c = tensors_c[j] if tensors_c is not None else None
                if merge_mode == "Normal":
                    # Linear interpolation: (1 - alpha) * A + alpha * B.
                    merged_tensor = torch.lerp(tensor_a, tensor_b, alpha)
                    # BUG FIX: log message typo "Nomal" corrected.
                    print(f"NormalMerged:{alpha}:{key}")
                elif merge_mode == "Add difference (B-C to Current)":
                    merged_tensor = current_state[key] + alpha * (tensor_b - tensor_c)
                    print(f"(B-C to Current):{alpha}:{key}")
                elif merge_mode == "Add difference (A + (B-C) to Current)":
                    merged_tensor = tensor_a + alpha * (tensor_b - tensor_c)
                    print(f"(A + (B-C) to Current):{alpha}:{key}")
                else:
                    # BUG FIX: the original f-string dropped the mode value.
                    raise ValueError(f"Invalid merge mode: {merge_mode}")
                # In-place copy keeps the module's parameter objects intact.
                current_state[key].copy_(merged_tensor.to(device))
|