File size: 7,692 Bytes
2e50127
 
 
 
 
 
 
 
 
 
 
 
 
828547d
2e50127
 
 
 
 
 
828547d
 
 
2e50127
 
 
 
 
 
828547d
2e50127
828547d
 
 
 
 
2e50127
828547d
 
2e50127
828547d
 
 
 
2e50127
 
 
828547d
 
 
 
 
 
 
 
 
 
 
 
 
 
2e50127
828547d
2e50127
 
 
 
 
 
 
 
 
 
 
 
828547d
 
2e50127
 
828547d
2e50127
 
 
 
 
 
 
 
 
 
 
 
 
828547d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e50127
828547d
 
2e50127
828547d
 
 
 
2e50127
828547d
 
 
 
 
2e50127
828547d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
from safetensors.torch import safe_open
from modules import scripts, sd_models, shared
import gradio as gr
from modules.processing import process_images


class KeyBasedModelMerger(scripts.Script):
    """WebUI script that merges selected tensors from other checkpoints into the
    currently loaded model.

    Keys are matched by substring against the live model's state_dict, and each
    matching tensor is merged in place (via ``Tensor.copy_``) according to the
    chosen mode: a plain lerp between models A and B, or an "add difference"
    of (B - C) applied to the current model or to model A.
    """

    def title(self):
        """Name shown in the WebUI scripts dropdown."""
        return "Key-based model merging"

    def ui(self, is_txt2img):
        """Build the Gradio controls.

        Returns the component list in the same order `run` expects its
        arguments after `p`.
        """
        model_names = sorted(sd_models.checkpoints_list.keys(), key=str.casefold)
        default_model = model_names[0] if model_names else None

        model_a_dropdown = gr.Dropdown(
            label="Model A", choices=model_names, value=default_model
        )
        model_b_dropdown = gr.Dropdown(
            label="Model B", choices=model_names, value=default_model
        )
        model_c_dropdown = gr.Dropdown(
            label="Model C (Add difference mode用)", choices=model_names, value=default_model
        )
        keys_and_alphas_textbox = gr.Textbox(
            label="マージするテンソルのキーとマージ比率 (部分一致, 1行に1つ, カンマ区切り)",
            lines=5,
            placeholder="例:\nmodel.diffusion_model.input_blocks.0,0.5\nmodel.diffusion_model.middle_block,0.3"
        )
        merge_checkbox = gr.Checkbox(label="モデルのマージを有効にする", value=True)
        use_gpu_checkbox = gr.Checkbox(label="GPUを使用", value=True)
        batch_size_slider = gr.Slider(minimum=1, maximum=500, step=1, value=250, label="KeyMgerge_BatchSize")
        merge_mode_dropdown = gr.Dropdown(
            label="Merge Mode",
            choices=["Normal", "Add difference (B-C to Current)", "Add difference (A + (B-C) to Current)"],
            value="Normal"
        )

        return [model_a_dropdown, model_b_dropdown, model_c_dropdown, keys_and_alphas_textbox,
                merge_checkbox, use_gpu_checkbox, batch_size_slider, merge_mode_dropdown]

    def run(self, p, model_a_name, model_b_name, model_c_name, keys_and_alphas_str,
            merge_enabled, use_gpu, batch_size, merge_mode):
        """Merge the requested tensors into the live model, then render.

        Parameters mirror the `ui` components. Returns the result of
        ``process_images(p)``, or `p` unchanged when validation fails.
        Raises ValueError for an unknown `merge_mode`.
        """
        if not model_b_name:
            print("Error: Model B is not selected.")
            return p

        try:
            # Resolve only the checkpoint files the selected mode actually needs.
            if merge_mode == "Normal":
                model_a_filename = sd_models.checkpoints_list[model_a_name].filename
                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
            elif merge_mode == "Add difference (B-C to Current)":
                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
                model_c_filename = sd_models.checkpoints_list[model_c_name].filename
            elif merge_mode == "Add difference (A + (B-C) to Current)":
                model_a_filename = sd_models.checkpoints_list[model_a_name].filename
                model_b_filename = sd_models.checkpoints_list[model_b_name].filename
                model_c_filename = sd_models.checkpoints_list[model_c_name].filename
            else:
                # BUGFIX: previously the f-string dropped the mode value.
                raise ValueError(f"Invalid merge mode: {merge_mode}")

        except KeyError as e:
            # BUGFIX: previously the f-string dropped the missing key.
            print(f"Error: Selected model is not found in checkpoints list. {e}")
            return p

        if merge_enabled:
            # Parse "key_substring,alpha" lines; malformed alphas are skipped
            # with a warning rather than aborting the whole merge.
            input_keys_and_alphas = []
            for line in keys_and_alphas_str.split("\n"):
                if "," not in line:
                    continue
                key_part, alpha_str = line.split(",", 1)
                try:
                    input_keys_and_alphas.append((key_part, float(alpha_str)))
                except ValueError:
                    # BUGFIX: previously the f-string dropped the line content.
                    print(f"Invalid alpha value in line '{line}', skipping...")

            # Snapshot the key list once; substring-match each requested key
            # against every state_dict key. Later lines win on duplicates.
            model_keys = list(shared.sd_model.state_dict().keys())
            final_keys_and_alphas = {}
            for key_part, alpha in input_keys_and_alphas:
                for model_key in model_keys:
                    if key_part in model_key:
                        final_keys_and_alphas[model_key] = alpha

            # Tensors are loaded directly onto this device by safe_open.
            device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"

            batched_keys = list(final_keys_and_alphas.items())

            # Open only the files the mode needs; unused slots are passed as None.
            if merge_mode == "Normal":
                with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
                     safe_open(model_b_filename, framework="pt", device=device) as f_b:
                    self._merge_models(f_a, f_b, None, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
            elif merge_mode == "Add difference (B-C to Current)":
                with safe_open(model_b_filename, framework="pt", device=device) as f_b, \
                     safe_open(model_c_filename, framework="pt", device=device) as f_c:
                    self._merge_models(None, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
            elif merge_mode == "Add difference (A + (B-C) to Current)":
                with safe_open(model_a_filename, framework="pt", device=device) as f_a, \
                     safe_open(model_b_filename, framework="pt", device=device) as f_b, \
                     safe_open(model_c_filename, framework="pt", device=device) as f_c:
                    self._merge_models(f_a, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device)
            else:
                # Unreachable in practice (validated above), kept as a guard.
                raise ValueError(f"Invalid merge mode: {merge_mode}")

        # Render with the (possibly merged) live model.
        return process_images(p)

    def _merge_models(self, f_a, f_b, f_c, batched_keys, final_keys_and_alphas, batch_size, merge_mode, device):
        """Merge tensors into the live model, `batch_size` keys at a time.

        `f_a`/`f_b`/`f_c` are open safetensors handles (None when the mode does
        not use that model). Mutates the loaded model in place via `copy_`.
        """
        # state_dict() tensors are live references into the loaded model, so
        # copy_() below updates it in place. PERF: fetch the dict once instead
        # of rebuilding it for every key as before.
        current_state = shared.sd_model.state_dict()

        for start in range(0, len(batched_keys), batch_size):
            batch = batched_keys[start:start + batch_size]

            # Pull each batch's tensors up front from the open files.
            tensors_a = [f_a.get_tensor(key) for key, _ in batch] if f_a is not None else None
            tensors_b = [f_b.get_tensor(key) for key, _ in batch] if f_b is not None else None
            tensors_c = [f_c.get_tensor(key) for key, _ in batch] if f_c is not None else None

            for j, (key, alpha) in enumerate(batch):
                tensor_a = tensors_a[j] if tensors_a is not None else None
                tensor_b = tensors_b[j] if tensors_b is not None else None
                tensor_c = tensors_c[j] if tensors_c is not None else None

                if merge_mode == "Normal":
                    # Linear interpolation: a*(1-alpha) + b*alpha.
                    merged_tensor = torch.lerp(tensor_a, tensor_b, alpha)
                    # BUGFIX: "Nomal" typo in the log message.
                    print(f"NormalMerged:{alpha}:{key}")
                elif merge_mode == "Add difference (B-C to Current)":
                    merged_tensor = current_state[key] + alpha * (tensor_b - tensor_c)
                    print(f"(B-C to Current):{alpha}:{key}")
                elif merge_mode == "Add difference (A + (B-C) to Current)":
                    merged_tensor = tensor_a + alpha * (tensor_b - tensor_c)
                    print(f"(A + (B-C) to Current):{alpha}:{key}")
                else:
                    # BUGFIX: previously the f-string dropped the mode value.
                    raise ValueError(f"Invalid merge mode: {merge_mode}")

                # In-place write into the live model's tensor.
                current_state[key].copy_(merged_tensor.to(device))