subaqua committed on
Commit
564ed29
1 Parent(s): 290ab2d

Update as_safetensors+fp16.ipynb

Browse files
Files changed (1) hide show
  1. as_safetensors+fp16.ipynb +45 -7
as_safetensors+fp16.ipynb CHANGED
@@ -121,7 +121,7 @@
121
  "for model in models:\n",
122
  " if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
123
  " wget.download(model)\n",
124
- " elif model.endswith((\".ckpt\", \".safetensors\", \".yaml\")):\n",
125
  " shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n",
126
  " else:\n",
127
  " print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"
@@ -150,16 +150,26 @@
150
  " from huggingface_hub import HfApi, Repository\n",
151
  "\n",
152
  "models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
 
153
  "as_fp16 = True #@param {type:\"boolean\"}\n",
154
- "save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
155
- "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
156
  "clip_fix = \"fix err key\" #@param [\"off\", \"fix err key\", \"del err key\"]\n",
157
  "uninvited_key = \"cond_stage_model.transformer.text_model.embeddings.position_ids\"\n",
 
 
 
158
  "#@markdown 変換したモデルをHugging Faceに投稿する場合は「yourname/yourrepo」の形式で投稿先リポジトリを指定<br>\n",
159
  "#@markdown 投稿しない場合は何も入力しない<br>\n",
160
  "# 5GB以上のファイルを投稿する場合は、投稿先リポジトリを丸ごとダウンロードする工程が挟まるので、時間がかかる場合があります\n",
161
  "repo_id = \"\" #@param {type:\"string\"}\n",
162
  "\n",
 
 
 
 
 
 
 
 
163
  "def upload_to_hugging_face(file_name):\n",
164
  " api = HfApi()\n",
165
  " api.upload_file(path_or_fileobj=file_name,\n",
@@ -183,6 +193,12 @@
183
  " upload_to_hugging_face(file_name)\n",
184
  " os.chdir(\"/content\")\n",
185
  "\n",
 
 
 
 
 
 
186
  "if models == \"\":\n",
187
  " models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
188
  "else:\n",
@@ -195,15 +211,20 @@
195
  " elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
196
  " print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
197
  " else:\n",
198
- " load_model = partial(safetensors.torch.load_file, device=\"cpu\") if model_ext == \".safetensors\" else partial(torch.load, map_location=torch.device(\"cpu\"))\n",
199
  " save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
200
  " # convert model\n",
201
  " with torch.no_grad():\n",
202
  " weights = load_model(model)\n",
203
  " if \"state_dict\" in weights:\n",
204
  " weights = weights[\"state_dict\"]\n",
 
 
 
 
 
205
  " if as_fp16:\n",
206
- " model_name = model_name + \"-fp16\"\n",
207
  " for key in weights.keys():\n",
208
  " weights[key] = weights[key].half()\n",
209
  " if uninvited_key in weights:\n",
@@ -211,6 +232,14 @@
211
  " del weights[uninvited_key]\n",
212
  " if clip_fix == \"fix err key\":\n",
213
  " weights[uninvited_key] = torch.tensor([list(range(77))],dtype=torch.int64)\n",
 
 
 
 
 
 
 
 
214
  " if save_directly_to_Google_Drive:\n",
215
  " os.chdir(\"/content/drive/MyDrive\")\n",
216
  " save_model(weights, saved_model := model_name + save_type)\n",
@@ -226,8 +255,8 @@
226
  "!reset"
227
  ],
228
  "metadata": {
229
- "id": "9OmSG98HxJg2",
230
- "cellView": "form"
231
  },
232
  "execution_count": null,
233
  "outputs": []
@@ -254,6 +283,15 @@
254
  "id": "0SUK6Alv2ItS"
255
  }
256
  },
 
 
 
 
 
 
 
 
 
257
  {
258
  "cell_type": "markdown",
259
  "source": [
 
121
  "for model in models:\n",
122
  " if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
123
  " wget.download(model)\n",
124
+ " elif model.endswith((\".ckpt\", \".safetensors\", \".yaml\", \".pt\")):\n",
125
  " shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n",
126
  " else:\n",
127
  " print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"
 
150
  " from huggingface_hub import HfApi, Repository\n",
151
  "\n",
152
  "models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
153
+ "pruning = True #@param {type:\"boolean\"}\n",
154
  "as_fp16 = True #@param {type:\"boolean\"}\n",
 
 
155
  "clip_fix = \"fix err key\" #@param [\"off\", \"fix err key\", \"del err key\"]\n",
156
  "uninvited_key = \"cond_stage_model.transformer.text_model.embeddings.position_ids\"\n",
157
+ "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
158
+ "merge_vae = \"\" #@param [\"\", \"vae-ft-mse-840000-ema-pruned.ckpt\", \"kl-f8-anime.ckpt\", \"kl-f8-anime2.ckpt\", \"Anything-V3.0.vae.pt\"] {allow-input: true}\n",
159
+ "save_directly_to_Google_Drive = False #@param {type:\"boolean\"}\n",
160
  "#@markdown 変換したモデルをHugging Faceに投稿する場合は「yourname/yourrepo」の形式で投稿先リポジトリを指定<br>\n",
161
  "#@markdown 投稿しない場合は何も入力しない<br>\n",
162
  "# 5GB以上のファイルを投稿する場合は、投稿先リポジトリを丸ごとダウンロードする工程が挟まるので、時間がかかる場合があります\n",
163
  "repo_id = \"\" #@param {type:\"string\"}\n",
164
  "\n",
165
+ "vae_preset = {\n",
166
+ " \"vae-ft-mse-840000-ema-pruned.ckpt\": \"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\",\n",
167
+ " \"kl-f8-anime.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\",\n",
168
+ " \"kl-f8-anime2.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\",\n",
169
+ " \"Anything-V3.0.vae.pt\": \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0.vae.pt\"}\n",
170
+ "if (merge_vae in vae_preset) & (not os.path.exists(merge_vae)):\n",
171
+ " wget.download(vae_preset[merge_vae])\n",
172
+ "\n",
173
  "def upload_to_hugging_face(file_name):\n",
174
  " api = HfApi()\n",
175
  " api.upload_file(path_or_fileobj=file_name,\n",
 
193
  " upload_to_hugging_face(file_name)\n",
194
  " os.chdir(\"/content\")\n",
195
  "\n",
196
+ "#use `str.removeprefix(p)` in python 3.9+\n",
197
+ "def remove_prefix(input_string, prefix):\n",
198
+ " if prefix and input_string.startswith(prefix):\n",
199
+ " return input_string[len(prefix):]\n",
200
+ " return input_string\n",
201
+ "\n",
202
  "if models == \"\":\n",
203
  " models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
204
  "else:\n",
 
211
  " elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
212
  " print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
213
  " else:\n",
214
+ " load_model = lambda filename: partial(safetensors.torch.load_file, device=\"cpu\")(filename) if os.path.splitext(filename)[1] == \".safetensors\" else partial(torch.load, map_location=torch.device(\"cpu\"))(filename)\n",
215
  " save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
216
  " # convert model\n",
217
  " with torch.no_grad():\n",
218
  " weights = load_model(model)\n",
219
  " if \"state_dict\" in weights:\n",
220
  " weights = weights[\"state_dict\"]\n",
221
+ " if pruning:\n",
222
+ " model_name += \"-pruned\"\n",
223
+ " for key in list(weights.keys()):\n",
224
+ " if key.startswith(\"model_ema.\"):\n",
225
+ " del weights[key]\n",
226
  " if as_fp16:\n",
227
+ " model_name += \"-fp16\"\n",
228
  " for key in weights.keys():\n",
229
  " weights[key] = weights[key].half()\n",
230
  " if uninvited_key in weights:\n",
 
232
  " del weights[uninvited_key]\n",
233
  " if clip_fix == \"fix err key\":\n",
234
  " weights[uninvited_key] = torch.tensor([list(range(77))],dtype=torch.int64)\n",
235
+ " if merge_vae != \"\":\n",
236
+ " vae_weights = load_model(merge_vae)\n",
237
+ " if \"state_dict\" in vae_weights:\n",
238
+ " vae_weights = vae_weights[\"state_dict\"]\n",
239
+ " for key in weights.keys():\n",
240
+ " if key.startswith(\"first_stage_model.\"):\n",
241
+ " weights[key] = vae_weights[remove_prefix(key, \"first_stage_model.\")]\n",
242
+ " del vae_weights\n",
243
  " if save_directly_to_Google_Drive:\n",
244
  " os.chdir(\"/content/drive/MyDrive\")\n",
245
  " save_model(weights, saved_model := model_name + save_type)\n",
 
255
  "!reset"
256
  ],
257
  "metadata": {
258
+ "cellView": "form",
259
+ "id": "QSzZqGygdXM9"
260
  },
261
  "execution_count": null,
262
  "outputs": []
 
283
  "id": "0SUK6Alv2ItS"
284
  }
285
  },
286
+ {
287
+ "cell_type": "markdown",
288
+ "source": [
289
+ "Hugging Faceに5GB以上のファイルを投稿する場合はメモリ消費量が約2倍になります"
290
+ ],
291
+ "metadata": {
292
+ "id": "8KU7VgNnE0Fy"
293
+ }
294
+ },
295
  {
296
  "cell_type": "markdown",
297
  "source": [