{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "### A script for converting models between the .ckpt and .safetensors formats\n",
        "#### The .yaml config files that ship with SD2.x models are converted as well\n",
        "#### Models can optionally be saved as fp16"
      ],
      "metadata": {
        "id": "fAIY_GORNEYa"
      }
    },
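    {
      "cell_type": "markdown",
      "source": [
        "A minimal sketch of the core conversion, run on a tiny synthetic checkpoint rather than a real model. It is an illustration only, not the exact code of the conversion cell. The key names and file names are hypothetical, and it assumes `torch` and `safetensors` are installed. It unwraps `state_dict`, drops the `model_ema.` weights (pruning), halves only floating-point tensors (fp16) and saves the result as `.safetensors`.\n",
        "\n",
        "```python\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "import safetensors.torch\n",
        "import torch\n",
        "\n",
        "with tempfile.TemporaryDirectory() as tmp:\n",
        "  # hypothetical toy checkpoint; a .ckpt usually wraps its weights in a \"state_dict\" entry\n",
        "  ckpt_path = os.path.join(tmp, \"toy.ckpt\")\n",
        "  torch.save({\"state_dict\": {\n",
        "      \"model.diffusion_model.w\": torch.randn(2, 2),\n",
        "      \"model_ema.diffusion_modelw\": torch.randn(2, 2),\n",
        "      \"cond_stage_model.transformer.text_model.embeddings.position_ids\": torch.arange(77).unsqueeze(0),\n",
        "  }}, ckpt_path)\n",
        "\n",
        "  weights = torch.load(ckpt_path, map_location=\"cpu\")\n",
        "  weights = weights.get(\"state_dict\", weights)\n",
        "  # pruning: drop the EMA copy of the UNet weights\n",
        "  weights = {k: v for k, v in weights.items() if not k.startswith(\"model_ema.\")}\n",
        "  # fp16: halve floating-point tensors only, so the int64 position ids keep their dtype\n",
        "  weights = {k: v.half() if v.is_floating_point() else v for k, v in weights.items()}\n",
        "\n",
        "  out_path = os.path.join(tmp, \"toy-pruned-fp16.safetensors\")\n",
        "  safetensors.torch.save_file(weights, out_path)\n",
        "  for key, tensor in safetensors.torch.load_file(out_path, device=\"cpu\").items():\n",
        "    print(key, tensor.dtype)\n",
        "```\n",
        "\n",
        "This prints `torch.float16` for the UNet weight and `torch.int64` for the position ids. The EMA key no longer appears."
      ],
      "metadata": {}
    },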
    {
      "cell_type": "markdown",
      "source": [
        "Run the following code first"
      ],
      "metadata": {
        "id": "OnuCk_wNLM_D"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install torch safetensors\n",
        "!pip install pytorch-lightning\n",
        "!pip install wget"
      ],
      "metadata": {
        "id": "pXr7oNJzwwgU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "To read or write files on Google Drive, run the following code"
      ],
      "metadata": {
        "id": "NsncqZOha2e0"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive\n",
        "drive.mount(\"/content/drive\")"
      ],
      "metadata": {
        "id": "liEiK8Iioscq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title <font size=\"-0\">To upload the converted model to Hugging Face, run the following code</font>\n",
        "#@markdown 1. Open [this page](https://huggingface.co./settings/tokens), click New token, and create an Access Token with any Name and Role=write\n",
        "#@markdown 2. Copy the token, paste it into the field below, and run this cell\n",
        "!pip install huggingface_hub\n",
        "from huggingface_hub import login\n",
        "token = \"\" #@param {type:\"string\"}\n",
        "login(token=token)"
      ],
      "metadata": {
        "cellView": "form",
        "id": "mJO8RdvIINA-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "Replace the links and other values below with your own, then run both of the following cells in order, from top to bottom"
      ],
      "metadata": {
        "id": "7Ils-K70k15Y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title <font size=\"-0\">Download models</font>\n",
        "#@markdown Specify any number of {model names on Google Drive or model download links}, separated by commas\n",
        "#@markdown - Model on Google Drive: give its path relative to My Drive\n",
        "#@markdown - Download link: right-click the download link on Hugging Face or a similar site, choose Copy link address, and paste it in place of the links below\n",
        "import shutil\n",
        "import urllib.parse\n",
        "import wget\n",
        "import os\n",
        "\n",
        "models = \"Specify_the_model_in_this_way_if_the_model_is_on_My_Drive.safetensors, https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt, https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
        "models = [m.strip() for m in models.split(\",\")]\n",
        "for model in models:\n",
        "  if urllib.parse.urlparse(model).scheme: # the entry is a URL\n",
        "    wget.download(model)\n",
        "  elif model.endswith((\".ckpt\", \".safetensors\", \".yaml\", \".pt\")):\n",
        "    drive_path = \"/content/drive/MyDrive/\" + model\n",
        "    if os.path.exists(drive_path):\n",
        "      shutil.copy(drive_path, \"/content/\" + os.path.basename(model)) # copy the model from My Drive\n",
        "    else:\n",
        "      print(f\"\\\"{drive_path}\\\" was not found; is Google Drive mounted?\")\n",
        "  else:\n",
        "    print(f\"\\\"{model}\\\" is neither a URL nor a file in a supported format\")"
      ],
      "metadata": {
        "id": "4vd3A09AxJE0",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title <font size=\"-0\">Convert models</font>\n",
        "#@markdown Specify any number of models to convert, separated by commas<br>\n",
        "#@markdown If left empty, every model loaded into /content is converted<br>\n",
        "#@markdown Models can be subtracted with the syntax `a.ckpt, -, b.safetensors`\n",
        "import os\n",
        "import glob\n",
        "import torch\n",
        "import safetensors.torch\n",
        "\n",
        "from sys import modules\n",
        "if \"huggingface_hub\" in modules:\n",
        "  from huggingface_hub import HfApi, Repository\n",
        "\n",
        "models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
        "pruning = True #@param {type:\"boolean\"}\n",
        "as_fp16 = True #@param {type:\"boolean\"}\n",
        "clip_fix = \"fix err key\" #@param [\"off\", \"fix err key\", \"del err key\"]\n",
        "uninvited_key = \"cond_stage_model.transformer.text_model.embeddings.position_ids\"\n",
        "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
        "merge_vae = \"\" #@param [\"\", \"vae-ft-mse-840000-ema-pruned.ckpt\", \"kl-f8-anime.ckpt\", \"kl-f8-anime2.ckpt\", \"anything-v4.0.vae.pt\"] {allow-input: true}\n",
        "save_directly_to_Google_Drive = False #@param {type:\"boolean\"}\n",
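        "#@markdown To upload the converted model to Hugging Face, specify the destination repository as `yourname/yourrepo`<br>\n",
        "#@markdown Leave it empty if you don't want to upload<br>\n",
        "# Uploading a file of 5 GB or more first clones the destination repository, so it may take a while\n",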
        "repo_id = \"\" #@param {type:\"string\"}\n",
        "\n",
        "vae_preset = {\n",
        "    \"vae-ft-mse-840000-ema-pruned.ckpt\": \"https://huggingface.co./stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\",\n",
        "    \"kl-f8-anime.ckpt\": \"https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\",\n",
        "    \"kl-f8-anime2.ckpt\": \"https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\",\n",
        "    \"anything-v4.0.vae.pt\": \"https://huggingface.co./andite/anything-v4.0/resolve/main/anything-v4.0.vae.pt\"}\n",
        "if (merge_vae in vae_preset) and (not os.path.exists(merge_vae)):\n",
        "  wget.download(vae_preset[merge_vae])\n",
        "\n",
        "def upload_to_hugging_face(file_name):\n",
        "  api = HfApi()\n",
        "  api.upload_file(path_or_fileobj=file_name,\n",
        "    path_in_repo=file_name,\n",
        "    repo_id=repo_id,\n",
        "  )\n",
        "\n",
        "def convert_yaml(file_name):\n",
        "  with open(file_name) as f:\n",
        "    yaml = f.read()\n",
        "  if save_directly_to_Google_Drive:\n",
        "    os.chdir(\"/content/drive/MyDrive\")\n",
        "  is_safe = save_type == \".safetensors\"\n",
        "  yaml = yaml.replace(f\"use_checkpoint: {is_safe}\", f\"use_checkpoint: {not is_safe}\")\n",
        "  # give the .yaml the same suffixes as the converted model so that the two file names match\n",
        "  base_name = os.path.splitext(file_name)[0]\n",
        "  if pruning:\n",
        "    base_name += \"-pruned\"\n",
        "  if as_fp16:\n",
        "    yaml = yaml.replace(\"use_fp16: False\", \"use_fp16: True\")\n",
        "    base_name += \"-fp16\"\n",
        "  file_name = base_name + \".yaml\"\n",
        "  with open(file_name, mode=\"w\") as f:\n",
        "    f.write(yaml)\n",
        "  if repo_id != \"\":\n",
        "    upload_to_hugging_face(file_name)\n",
        "  os.chdir(\"/content\")\n",
        "\n",
        "#use `str.removeprefix(p)` in python 3.9+\n",
        "def remove_prefix(input_string, prefix):\n",
        "  if prefix and input_string.startswith(prefix):\n",
        "    return input_string[len(prefix):]\n",
        "  return input_string\n",
        "\n",
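        "def load_model(m):\n",
        "  if os.path.splitext(m)[1] == \".safetensors\":\n",
        "    return safetensors.torch.load_file(m, device=\"cpu\")\n",
        "  # .ckpt/.pt files are pickles that may contain non-tensor objects; weights_only=False keeps them loadable on torch 2.6+, so only load files you trust\n",
        "  return torch.load(m, map_location=torch.device(\"cpu\"), weights_only=False)\n",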
        "save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
        "\n",
        "# --- def merge ---#\n",
        "@torch.no_grad()\n",
        "def merged(model_a, model_b, fab, fa, fb):\n",
        "  weights_a = load_model(model_a) if isinstance(model_a, str) else model_a\n",
        "  weights_b = load_model(model_b) if isinstance(model_b, str) else model_b\n",
        "  if \"state_dict\" in weights_a:\n",
        "    weights_a = weights_a[\"state_dict\"]\n",
        "  if \"state_dict\" in weights_b:\n",
        "    weights_b = weights_b[\"state_dict\"]\n",
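        "  # iterate over the union of both key sets (A's keys first) so that keys found only in B are handled too\n",
        "  for key in list(dict.fromkeys([*weights_a.keys(), *weights_b.keys()])):\n",
        "    if isinstance(weights_a.get(key), dict):\n",
        "      del weights_a[key]\n",
        "    if isinstance(weights_b.get(key), dict):\n",
        "      del weights_b[key]\n",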
        "    if key.startswith(\"model.\") or key.startswith(\"model_ema.\"):\n",
        "      if (key in weights_a) and (key in weights_b):\n",
        "        weights_a[key] = fab(weights_a[key], weights_b[key])\n",
        "        del weights_b[key]\n",
        "      elif key in weights_a:\n",
        "        weights_a[key] = fa(weights_a[key])\n",
        "      elif key in weights_b:\n",
        "        weights_a[key] = fb(weights_b[key])\n",
        "        del weights_b[key]\n",
        "  del weights_b\n",
        "  return weights_a\n",
        "\n",
        "def add(model_a, model_b):\n",
        "  return merged(model_a, model_b, lambda a, b: a + b, lambda a: a, lambda b: b)\n",
        "\n",
        "def difference(model_a, model_b):\n",
        "  return merged(model_a, model_b, lambda a, b: a - b, lambda a: a, lambda b: -b)\n",
        "\n",
        "def add_difference(model_a, model_b, model_c):\n",
        "  return add(model_a, difference(model_b, model_c))\n",
        "# --- end merge ---#\n",
        "\n",
        "if models == \"\":\n",
        "  models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\") if not os.path.basename(m) in vae_preset]\n",
        "else:\n",
        "  models = [m.strip() for m in models.split(\",\")]\n",
        "\n",
        "for i, model in enumerate(models):\n",
        "  model_name, model_ext = os.path.splitext(model)\n",
        "  # in `a.ckpt, -, b.ckpt`, skip the \"-\" and b.ckpt entries; they are consumed when a.ckpt is processed\n",
        "  if (models[i] == \"-\") or (i > 0 and models[i - 1] == \"-\"):\n",
        "    continue\n",
        "  if model_ext == \".yaml\":\n",
        "    convert_yaml(model)\n",
        "  elif (model_ext != \".safetensors\") and (model_ext != \".ckpt\"):\n",
        "    print(\"Only .ckpt, .safetensors and .yaml files are supported\\n\" + f\"\\\"{model}\\\" is not in a supported format\")\n",
        "  else:\n",
        "    # convert model\n",
        "    with torch.no_grad():\n",
        "      # in `a.ckpt, -, b.ckpt`, a.ckpt is converted as the difference a - b\n",
        "      if (i < len(models) - 1) and (models[i + 1] == \"-\"):\n",
        "        weights = difference(model, models[i + 2])\n",
        "        model_name = f\"{model_name}-{os.path.splitext(models[i + 2])[0]}\"\n",
        "      # otherwise\n",
        "      else:\n",
        "        weights = load_model(model)\n",
        "      if \"state_dict\" in weights:\n",
        "        weights = weights[\"state_dict\"]\n",
        "      for key in list(weights.keys()):\n",
        "        if isinstance(weights[key], dict):\n",
        "          del weights[key] # to fix the broken model\n",
        "      if pruning:\n",
        "        model_name += \"-pruned\"\n",
        "        for key in list(weights.keys()):\n",
        "          if key.startswith(\"model_ema.\"):\n",
        "            del weights[key]\n",
        "      # merge the VAE before the fp16 step so that its weights are converted as well\n",
        "      if merge_vae != \"\":\n",
        "        vae_weights = load_model(merge_vae)\n",
        "        if \"state_dict\" in vae_weights:\n",
        "          vae_weights = vae_weights[\"state_dict\"]\n",
        "        for key in weights.keys():\n",
        "          vae_key = remove_prefix(key, \"first_stage_model.\")\n",
        "          if key.startswith(\"first_stage_model.\") and vae_key in vae_weights:\n",
        "            weights[key] = vae_weights[vae_key]\n",
        "        del vae_weights\n",
        "      if as_fp16:\n",
        "        model_name += \"-fp16\"\n",
        "        for key in weights.keys():\n",
        "          # halve floating-point tensors only; integer tensors such as position_ids keep their dtype\n",
        "          if isinstance(weights[key], torch.Tensor) and weights[key].is_floating_point():\n",
        "            weights[key] = weights[key].half()\n",
        "      if uninvited_key in weights:\n",
        "        if clip_fix == \"del err key\":\n",
        "          del weights[uninvited_key]\n",
        "        if clip_fix == \"fix err key\":\n",
        "          weights[uninvited_key] = torch.tensor([list(range(77))],dtype=torch.int64)\n",
        "      if save_directly_to_Google_Drive:\n",
        "        os.chdir(\"/content/drive/MyDrive\")\n",
        "      save_model(weights, saved_model := model_name + save_type)\n",
        "      if repo_id != \"\":\n",
        "        if os.path.getsize(saved_model) >= 5*1000*1000*1000:\n",
        "          with Repository(os.path.basename(repo_id), clone_from=repo_id, skip_lfs_files=True, token=True).commit(commit_message=f\"Upload {saved_model} with huggingface_hub\", blocking=False):\n",
        "            save_model(weights, saved_model)\n",
        "        else:\n",
        "          upload_to_hugging_face(saved_model)\n",
        "      os.chdir(\"/content\")\n",
        "      del weights\n",
        "\n",
        "!reset"
      ],
      "metadata": {
        "cellView": "form",
        "id": "QSzZqGygdXM9"
      },
      "execution_count": null,
      "outputs": []
    },
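    {
      "cell_type": "markdown",
      "source": [
        "The `a.ckpt, -, b.safetensors` syntax builds a difference model. The function below is a simplified sketch of the intended semantics on plain dicts of tensors with hypothetical keys, not the cell's own code. Keys starting with `model.` or `model_ema.` behave as follows:\n",
        "- both models have the key: the result is `a - b`\n",
        "- only A has it: the result is `a`\n",
        "- only B has it: the result is `-b`\n",
        "\n",
        "Every other key, for example the VAE under `first_stage_model.`, is copied from A.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "def difference(weights_a, weights_b):\n",
        "  result = dict(weights_a)\n",
        "  # walk the union of both key sets, A's keys first\n",
        "  for key in dict.fromkeys([*weights_a, *weights_b]):\n",
        "    if not key.startswith((\"model.\", \"model_ema.\")):\n",
        "      continue  # non-UNet keys stay as they are in A\n",
        "    if key in weights_a and key in weights_b:\n",
        "      result[key] = weights_a[key] - weights_b[key]\n",
        "    elif key in weights_b:\n",
        "      result[key] = -weights_b[key]\n",
        "  return result\n",
        "\n",
        "a = {\"model.x\": torch.tensor([3.0]), \"first_stage_model.y\": torch.tensor([1.0])}\n",
        "b = {\"model.x\": torch.tensor([1.0]), \"model.z\": torch.tensor([2.0])}\n",
        "for key, value in difference(a, b).items():\n",
        "  print(key, value.tolist())\n",
        "```\n",
        "\n",
        "This prints `model.x [2.0]`, `first_stage_model.y [1.0]` and `model.z [-2.0]`."
      ],
      "metadata": {}
    },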
    {
      "cell_type": "markdown",
      "source": [
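        "When converting SD2.x models and similar, download and convert the accompanying config file (the .yaml file with the same name as the model) as well\n",
        "\n",
        "Specify it the same way as a model"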
      ],
      "metadata": {
        "id": "SWTFKmGFLec6"
      }
    },
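    {
      "cell_type": "markdown",
      "source": [
        "This sketch covers two details of the accompanying config. It assumes, as the note above says, that the .yaml must have the same base name as the model:\n",
        "- the config's file name follows the converted model's name\n",
        "- when saving as fp16, the `use_fp16` flag is switched on\n",
        "\n",
        "The helper names and the one-line config excerpt are hypothetical.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "def matching_yaml_name(model_file):\n",
        "  # the config shares the model's base name, e.g. foo.safetensors -> foo.yaml\n",
        "  return os.path.splitext(model_file)[0] + \".yaml\"\n",
        "\n",
        "def to_fp16_config(config_text):\n",
        "  # flip the UNet precision flag, as the conversion cell does when as_fp16 is checked\n",
        "  return config_text.replace(\"use_fp16: False\", \"use_fp16: True\")\n",
        "\n",
        "print(matching_yaml_name(\"wd-1-4-anime_e1-pruned-fp16.safetensors\"))\n",
        "print(to_fp16_config(\"use_fp16: False\"))\n",
        "```\n",
        "\n",
        "This prints `wd-1-4-anime_e1-pruned-fp16.yaml` and `use_fp16: True`."
      ],
      "metadata": {}
    },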
    {
      "cell_type": "markdown",
      "source": [
        "If the runtime crashes because it runs out of memory, use a smaller model or switch to a paid high-RAM runtime\n",
        "\n",
        "The standard runtime can convert models of up to 10 GB"
      ],
      "metadata": {
        "id": "0SUK6Alv2ItS"
      }
    },
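    {
      "cell_type": "markdown",
      "source": [
        "This gives a rough guide to the memory limit. A state dict's tensors take about (number of elements × bytes per element). The conversion cell keeps at least one full copy of the weights in RAM. When subtracting models it also holds the second model, and when merging a VAE it also holds the VAE. The helper below is a sketch that counts tensor payload only, so actual peak usage will be higher. The toy tensors are hypothetical.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "def state_dict_bytes(weights):\n",
        "  # tensor payload only: number of elements times bytes per element\n",
        "  return sum(t.numel() * t.element_size() for t in weights.values() if isinstance(t, torch.Tensor))\n",
        "\n",
        "toy = {\"a\": torch.zeros(1000, 1000), \"b\": torch.zeros(1000, dtype=torch.int64)}\n",
        "print(state_dict_bytes(toy))  # 4000000 (float32) + 8000 (int64) = 4008000\n",
        "half = {k: v.half() if v.is_floating_point() else v for k, v in toy.items()}\n",
        "print(state_dict_bytes(half))  # 2000000 (float16) + 8000 (int64) = 2008000\n",
        "```"
      ],
      "metadata": {}
    },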
    {
      "cell_type": "markdown",
      "source": [
        "Pick any model you like from sources such as [this list of models](https://huggingface.co./models?other=stable-diffusion)"
      ],
      "metadata": {
        "id": "yaLq5Nqe6an6"
      }
    }
  ]
}