Commit 828547d (verified) by 2vXpSwA7 · Parent(s): 2e50127

Added support for "Add difference" merging.

Files changed (1):
  1. keybased_modelmerger.py +81 -30
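For reference, each merge mode reduces to simple element-wise tensor arithmetic applied per matched key. A minimal standalone sketch of that arithmetic (merge_tensor is a hypothetical helper written for illustration, not part of the script):

import torch

def merge_tensor(mode, alpha, current=None, a=None, b=None, c=None):
    # "Normal": linear interpolation between model A and model B, weighted toward B by alpha
    if mode == "Normal":
        return torch.lerp(a, b, alpha)
    # "Add difference (B-C to Current)": add the scaled difference B - C onto the currently loaded tensor
    if mode == "Add difference (B-C to Current)":
        return current + alpha * (b - c)
    # "Add difference (A + (B-C) to Current)": replace the current tensor with A plus the scaled difference
    if mode == "Add difference (A + (B-C) to Current)":
        return a + alpha * (b - c)
    raise ValueError(f"Invalid merge mode: {mode}")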
keybased_modelmerger.py CHANGED
@@ -10,36 +10,57 @@ class KeyBasedModelMerger(scripts.Script):
         return "Key-based model merging"
 
     def ui(self, is_txt2img):
-        # Define the UI components
         model_names = sorted(sd_models.checkpoints_list.keys(), key=str.casefold)
-
+
         model_a_dropdown = gr.Dropdown(
             label="Model A", choices=model_names, value=model_names[0] if model_names else None
         )
         model_b_dropdown = gr.Dropdown(
             label="Model B", choices=model_names, value=model_names[0] if model_names else None
         )
+        model_c_dropdown = gr.Dropdown(
+            label="Model C (Add difference mode用)", choices=model_names, value=model_names[0] if model_names else None
+        )
         keys_and_alphas_textbox = gr.Textbox(
             label="マージするテンソルのキーとマージ比率 (部分一致, 1行に1つ, カンマ区切り)",
             lines=5,
             placeholder="例:\nmodel.diffusion_model.input_blocks.0,0.5\nmodel.diffusion_model.middle_block,0.3"
         )
         merge_checkbox = gr.Checkbox(label="モデルのマージを有効にする", value=True)
-        use_gpu_checkbox = gr.Checkbox(label="GPUを使用", value=True)  # GPU/CPU toggle checkbox
+        use_gpu_checkbox = gr.Checkbox(label="GPUを使用", value=True)
         batch_size_slider = gr.Slider(minimum=1, maximum=500, step=1, value=250, label="KeyMgerge_BatchSize")
+        merge_mode_dropdown = gr.Dropdown(
+            label="Merge Mode",
+            choices=["Normal", "Add difference (B-C to Current)", "Add difference (A + (B-C) to Current)"],
+            value="Normal"
+        )
 
-        return [model_a_dropdown, model_b_dropdown, keys_and_alphas_textbox, merge_checkbox, use_gpu_checkbox, batch_size_slider]
+        return [model_a_dropdown, model_b_dropdown, model_c_dropdown, keys_and_alphas_textbox,
+                merge_checkbox, use_gpu_checkbox, batch_size_slider, merge_mode_dropdown]
 
-    def run(self, p, model_a_name, model_b_name, keys_and_alphas_str, merge_enabled, use_gpu, batch_size):
-        if not model_a_name or not model_b_name:
-            print("Error: Model A or Model B is not selected.")
+    def run(self, p, model_a_name, model_b_name, model_c_name, keys_and_alphas_str,
+            merge_enabled, use_gpu, batch_size, merge_mode):
+        if not model_b_name:
+            print("Error: Model B is not selected.")
             return p
 
         try:
-            model_a_filename = sd_models.checkpoints_list[model_a_name].filename
-            model_b_filename = sd_models.checkpoints_list[model_b_name].filename
+            # Load only the model files the selected merge mode needs
+            if merge_mode == "Normal":
+                model_a_filename = sd_models.checkpoints_list[model_a_name].filename
+                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
+            elif merge_mode == "Add difference (B-C to Current)":
+                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
+                model_c_filename = sd_models.checkpoints_list[model_c_name].filename
+            elif merge_mode == "Add difference (A + (B-C) to Current)":
+                model_a_filename = sd_models.checkpoints_list[model_a_name].filename
+                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
+                model_c_filename = sd_models.checkpoints_list[model_c_name].filename
+            else:
+                raise ValueError(f"Invalid merge mode: {merge_mode}")
+
         except KeyError as e:
-            print(f"Error: Selected model is not found in checkpoints list. {e}")
+            print(f"Error: Selected model is not found in checkpoints list. {e}")
             return p
 
         # Merge processing
@@ -52,11 +73,11 @@ class KeyBasedModelMerger(scripts.Script):
                 alpha = float(alpha_str)
                 input_keys_and_alphas.append((key_part, alpha))
             except ValueError:
-                print(f"Invalid alpha value in line '{line}', skipping...")
-
+                print(f"Invalid alpha value in line '{line}', skipping...")
+
         # Build the list of keys from the state_dict up front
         model_keys = list(shared.sd_model.state_dict().keys())
-
+
         # Partial-match the requested key fragments
         final_keys_and_alphas = {}
         for key_part, alpha in input_keys_and_alphas:
@@ -70,24 +91,54 @@ class KeyBasedModelMerger(scripts.Script):
         # Process the keys in batches
         batched_keys = list(final_keys_and_alphas.items())
 
-        # Fetch the tensors from model A and model B in bulk
-        with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
-             safe_open(model_b_filename, framework="pt", device=device) as f_b:
+        # Open the model files required by the selected mode
+        if merge_mode == "Normal":
+            with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
+                 safe_open(model_b_filename, framework="pt", device=device) as f_b:
+                self._merge_models(f_a, f_b, None, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
+        elif merge_mode == "Add difference (B-C to Current)":
+            with safe_open(model_b_filename, framework="pt", device=device) as f_b, \
+                 safe_open(model_c_filename, framework="pt", device=device) as f_c:
+                self._merge_models(None, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
+        elif merge_mode == "Add difference (A + (B-C) to Current)":
+            with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
+                 safe_open(model_b_filename, framework="pt", device=device) as f_b, \
+                 safe_open(model_c_filename, framework="pt", device=device) as f_c:
+                self._merge_models(f_a, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
+        else:
+            raise ValueError(f"Invalid merge mode: {merge_mode}")
 
-            # Process batch by batch
-            for i in range(0, len(batched_keys), batch_size):
-                batch = batched_keys[i:i + batch_size]
+        # Run process_images as needed
+        return process_images(p)
 
-                # Fetch the tensors for the batch and merge them in one go
-                tensors_a = [f_a.get_tensor(key) for key, _ in batch]
-                tensors_b = [f_b.get_tensor(key) for key, _ in batch]
-                alphas = [final_keys_and_alphas[key] for key, _ in batch]
+    def _merge_models(self, f_a, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device):
+        # Process batch by batch
+        for i in range(0, len(batched_keys), batch_size):
+            batch = batched_keys[i:i + batch_size]
 
-                # Merge the batched tensors and apply them at once
-                for key, alpha, tensor_a, tensor_b in zip([key for key, _ in batch], alphas, tensors_a, tensors_b):
-                    # Apply the merge result directly to the state_dict
-                    shared.sd_model.state_dict()[key].copy_(torch.lerp(tensor_a, tensor_b, alpha).to(device))
-                    print(f"merged {alpha}:{key}")
+            # Fetch the tensors for the batch
+            tensors_a = [f_a.get_tensor(key) for key, _ in batch] if f_a is not None else None
+            tensors_b = [f_b.get_tensor(key) for key, _ in batch] if f_b is not None else None
+            tensors_c = [f_c.get_tensor(key) for key, _ in batch] if f_c is not None else None
+            alphas = [final_keys_and_alphas[key] for key, _ in batch]
 
-        # Run process_images as needed
-        return process_images(p)
+            # Run the merge for each key in the batch
+            for j, (key, alpha) in enumerate(batch):
+                tensor_a = tensors_a[j] if tensors_a is not None else None
+                tensor_b = tensors_b[j] if tensors_b is not None else None
+                tensor_c = tensors_c[j] if tensors_c is not None else None
+
+                if merge_mode == "Normal":
+                    merged_tensor = torch.lerp(tensor_a, tensor_b, alpha)
+                    print(f"NormalMerged:{alpha}:{key}")
+                elif merge_mode == "Add difference (B-C to Current)":
+                    merged_tensor = shared.sd_model.state_dict()[key] + alpha * (tensor_b - tensor_c)
+                    print(f"(B-C to Current):{alpha}:{key}")
+                elif merge_mode == "Add difference (A + (B-C) to Current)":
+                    merged_tensor = tensor_a + alpha * (tensor_b - tensor_c)
+                    print(f"(A + (B-C) to Current):{alpha}:{key}")
+                else:
+                    raise ValueError(f"Invalid merge mode: {merge_mode}")
+
+                shared.sd_model.state_dict()[key].copy_(merged_tensor.to(device))
+
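As a usage note, each textbox line of the form "key_part,alpha" is expanded to concrete state_dict keys by partial matching before the batched merge runs. A minimal sketch of that expansion, assuming the partial match is a plain substring test (the matching loop body falls outside the hunks above, so that detail is an assumption); expand_keys is a hypothetical standalone helper, not part of the script:

def expand_keys(keys_and_alphas_str, model_keys):
    # Parse "key_part,alpha" lines, then expand each key_part against the model's keys.
    final_keys_and_alphas = {}
    for line in keys_and_alphas_str.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            key_part, alpha_str = line.split(",", 1)
            alpha = float(alpha_str)
        except ValueError:
            print(f"Invalid alpha value in line '{line}', skipping...")
            continue
        for key in model_keys:
            if key_part.strip() in key:  # assumed substring ("partial") match
                final_keys_and_alphas[key] = alpha
    return final_keys_and_alphas

# Example: one line selects every key containing the given prefix at alpha 0.5
keys = [
    "model.diffusion_model.input_blocks.0.0.weight",
    "model.diffusion_model.middle_block.1.in_layers.0.weight",
]
print(expand_keys("model.diffusion_model.input_blocks.0,0.5", keys))
# -> {'model.diffusion_model.input_blocks.0.0.weight': 0.5}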