subaqua commited on
Commit
b1ee2e9
1 Parent(s): 19686cd

Update as_safetensors+fp16.ipynb

Browse files
Files changed (1) hide show
  1. as_safetensors+fp16.ipynb +60 -20
as_safetensors+fp16.ipynb CHANGED
@@ -130,12 +130,12 @@
130
  "source": [
131
  "#@title <font size=\"-0\">モデルを変換</font>\n",
132
  "#@markdown 変換するモデルをカンマ区切りで任意個指定<br>\n",
133
- "#@markdown 何も入力されていない場合は、読み込まれている全てのモデルが変換される\n",
 
134
  "import os\n",
135
  "import glob\n",
136
  "import torch\n",
137
  "import safetensors.torch\n",
138
- "from functools import partial\n",
139
  "\n",
140
  "from sys import modules\n",
141
  "if \"huggingface_hub\" in modules:\n",
@@ -147,7 +147,7 @@
147
  "clip_fix = \"fix err key\" #@param [\"off\", \"fix err key\", \"del err key\"]\n",
148
  "uninvited_key = \"cond_stage_model.transformer.text_model.embeddings.position_ids\"\n",
149
  "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
150
- "merge_vae = \"\" #@param [\"\", \"vae-ft-mse-840000-ema-pruned.ckpt\", \"kl-f8-anime.ckpt\", \"kl-f8-anime2.ckpt\", \"Anything-V3.0.vae.pt\"] {allow-input: true}\n",
151
  "save_directly_to_Google_Drive = False #@param {type:\"boolean\"}\n",
152
  "#@markdown 変換したモデルをHugging Faceに投稿する場合は「yourname/yourrepo」の形式で投稿先リポジトリを指定<br>\n",
153
  "#@markdown 投稿しない場合は何も入力しない<br>\n",
@@ -158,8 +158,8 @@
158
  " \"vae-ft-mse-840000-ema-pruned.ckpt\": \"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\",\n",
159
  " \"kl-f8-anime.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\",\n",
160
  " \"kl-f8-anime2.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\",\n",
161
- " \"Anything-V3.0.vae.pt\": \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0.vae.pt\"}\n",
162
- "if (merge_vae in vae_preset) & (not os.path.exists(merge_vae)):\n",
163
  " wget.download(vae_preset[merge_vae])\n",
164
  "\n",
165
  "def upload_to_hugging_face(file_name):\n",
@@ -191,25 +191,74 @@
191
  " return input_string[len(prefix):]\n",
192
  " return input_string\n",
193
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  "if models == \"\":\n",
195
- " models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
196
  "else:\n",
197
  " models = [m.strip() for m in models.split(\",\")]\n",
198
  "\n",
199
- "for model in models:\n",
200
  " model_name, model_ext = os.path.splitext(model)\n",
 
 
 
201
  " if model_ext == \".yaml\":\n",
202
  " convert_yaml(model)\n",
203
- " elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
204
  " print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
205
  " else:\n",
206
- " load_model = lambda filename: partial(safetensors.torch.load_file, device=\"cpu\")(filename) if os.path.splitext(filename)[1] == \".safetensors\" else partial(torch.load, map_location=torch.device(\"cpu\"))(filename)\n",
207
- " save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
208
  " # convert model\n",
209
  " with torch.no_grad():\n",
210
- " weights = load_model(model)\n",
 
 
 
 
 
 
211
  " if \"state_dict\" in weights:\n",
212
  " weights = weights[\"state_dict\"]\n",
 
 
 
213
  " if pruning:\n",
214
  " model_name += \"-pruned\"\n",
215
  " for key in list(weights.keys()):\n",
@@ -275,15 +324,6 @@
275
  "id": "0SUK6Alv2ItS"
276
  }
277
  },
278
- {
279
- "cell_type": "markdown",
280
- "source": [
281
- "Hugging Faceに5GB以上のファイルを投稿する場合はメモリ消費量が約2倍になります"
282
- ],
283
- "metadata": {
284
- "id": "8KU7VgNnE0Fy"
285
- }
286
- },
287
  {
288
  "cell_type": "markdown",
289
  "source": [
 
130
  "source": [
131
  "#@title <font size=\"-0\">モデルを変換</font>\n",
132
  "#@markdown 変換するモデルをカンマ区切りで任意個指定<br>\n",
133
+ "#@markdown 何も入力されていない場合は、読み込まれている全てのモデルが変換される<br>\n",
134
+ "#@markdown `a.ckpt, -, b.safetensors`のような形式でモデルの引き算ができます\n",
135
  "import os\n",
136
  "import glob\n",
137
  "import torch\n",
138
  "import safetensors.torch\n",
 
139
  "\n",
140
  "from sys import modules\n",
141
  "if \"huggingface_hub\" in modules:\n",
 
147
  "clip_fix = \"fix err key\" #@param [\"off\", \"fix err key\", \"del err key\"]\n",
148
  "uninvited_key = \"cond_stage_model.transformer.text_model.embeddings.position_ids\"\n",
149
  "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
150
+ "merge_vae = \"\" #@param [\"\", \"vae-ft-mse-840000-ema-pruned.ckpt\", \"kl-f8-anime.ckpt\", \"kl-f8-anime2.ckpt\", \"anything-v4.0.vae.pt\"] {allow-input: true}\n",
151
  "save_directly_to_Google_Drive = False #@param {type:\"boolean\"}\n",
152
  "#@markdown 変換したモデルをHugging Faceに投稿する場合は「yourname/yourrepo」の形式で投稿先リポジトリを指定<br>\n",
153
  "#@markdown 投稿しない場合は何も入力しない<br>\n",
 
158
  " \"vae-ft-mse-840000-ema-pruned.ckpt\": \"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\",\n",
159
  " \"kl-f8-anime.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\",\n",
160
  " \"kl-f8-anime2.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\",\n",
161
+ " \"anything-v4.0.vae.pt\": \"https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0.vae.pt\"}\n",
162
+ "if (merge_vae in vae_preset) and (not os.path.exists(merge_vae)):\n",
163
  " wget.download(vae_preset[merge_vae])\n",
164
  "\n",
165
  "def upload_to_hugging_face(file_name):\n",
 
191
  " return input_string[len(prefix):]\n",
192
  " return input_string\n",
193
  "\n",
194
+ "load_model = lambda m: safetensors.torch.load_file(m, device=\"cpu\") if os.path.splitext(m)[1] == \".safetensors\" else torch.load(m, map_location=torch.device(\"cpu\"))\n",
195
+ "save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
196
+ "\n",
197
+ "# --- def merge ---#\n",
198
+ "@torch.no_grad()\n",
199
+ "def merged(model_a, model_b, fab, fa, fb):\n",
200
+ " weights_a = load_model(model_a) if isinstance(model_a, str) else model_a\n",
201
+ " weights_b = load_model(model_b) if isinstance(model_b, str) else model_b\n",
202
+ " if \"state_dict\" in weights_a:\n",
203
+ " weights_a = weights_a[\"state_dict\"]\n",
204
+ " if \"state_dict\" in weights_b:\n",
205
+ " weights_b = weights_b[\"state_dict\"]\n",
206
+ " for key in list(weights_a.keys() or weights_b.keys()):\n",
207
+ " if isinstance(weights_a[key], dict):\n",
208
+ " del weights_a[key]\n",
209
+ " if isinstance(weights_b[key], dict):\n",
210
+ " del weights_b[key]\n",
211
+ " if key.startswith(\"model.\") or key.startswith(\"model_ema.\"):\n",
212
+ " if (key in weights_a) and (key in weights_b):\n",
213
+ " weights_a[key] = fab(weights_a[key], weights_b[key])\n",
214
+ " del weights_b[key]\n",
215
+ " elif key in weights_a:\n",
216
+ " weights_a[key] = fa(weights_a[key])\n",
217
+ " elif key in weights_b:\n",
218
+ " weights_a[key] = fb(weights_b[key])\n",
219
+ " del weights_b[key]\n",
220
+ " del weights_b\n",
221
+ " return weights_a\n",
222
+ "\n",
223
+ "def add(model_a, model_b):\n",
224
+ " return merged(model_a, model_b, lambda a, b: a + b, lambda a: a, lambda b: b)\n",
225
+ "\n",
226
+ "def difference(model_a, model_b):\n",
227
+ " return merged(model_a, model_b, lambda a, b: a - b, lambda a: a, lambda b: -b)\n",
228
+ "\n",
229
+ "def add_difference(model_a, model_b, model_c):\n",
230
+ " return add(model_a, difference(model_b, model_c))\n",
231
+ "# --- end merge ---#\n",
232
+ "\n",
233
  "if models == \"\":\n",
234
+ " models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\") if not os.path.basename(m) in vae_preset]\n",
235
  "else:\n",
236
  " models = [m.strip() for m in models.split(\",\")]\n",
237
  "\n",
238
+ "for i, model in enumerate(models):\n",
239
  " model_name, model_ext = os.path.splitext(model)\n",
240
+ " # a.ckpt, - ,b.ckpt # - or b.ckpt\n",
241
+ " if (models[i] == \"-\") or (models[i - 1] == \"-\"):\n",
242
+ " continue\n",
243
  " if model_ext == \".yaml\":\n",
244
  " convert_yaml(model)\n",
245
+ " elif (model_ext != \".safetensors\") and (model_ext != \".ckpt\"):\n",
246
  " print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
247
  " else:\n",
 
 
248
  " # convert model\n",
249
  " with torch.no_grad():\n",
250
+ " # a.ckpt, - ,b.ckpt # a.ckpt\n",
251
+ " if (i < len(models) - 1) and (models[i + 1] == \"-\"):\n",
252
+ " weights = difference(model, models[i + 2])\n",
253
+ " model_name = f\"{model_name}-{os.path.splitext(models[i + 2])[0]}\"\n",
254
+ " # otherwise\n",
255
+ " else:\n",
256
+ " weights = load_model(model)\n",
257
  " if \"state_dict\" in weights:\n",
258
  " weights = weights[\"state_dict\"]\n",
259
+ " for key in list(weights.keys()):\n",
260
+ " if isinstance(weights[key], dict):\n",
261
+ " del weights[key] # to fix the broken model\n",
262
  " if pruning:\n",
263
  " model_name += \"-pruned\"\n",
264
  " for key in list(weights.keys()):\n",
 
324
  "id": "0SUK6Alv2ItS"
325
  }
326
  },
 
 
 
 
 
 
 
 
 
327
  {
328
  "cell_type": "markdown",
329
  "source": [