Update as_safetensors+fp16.ipynb
Browse files- as_safetensors+fp16.ipynb +45 -7
as_safetensors+fp16.ipynb
CHANGED
@@ -121,7 +121,7 @@
|
|
121 |
"for model in models:\n",
|
122 |
" if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
|
123 |
" wget.download(model)\n",
|
124 |
-
" elif model.endswith((\".ckpt\", \".safetensors\", \".yaml\")):\n",
|
125 |
" shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n",
|
126 |
" else:\n",
|
127 |
" print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"
|
@@ -150,16 +150,26 @@
|
|
150 |
" from huggingface_hub import HfApi, Repository\n",
|
151 |
"\n",
|
152 |
"models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
|
|
|
153 |
"as_fp16 = True #@param {type:\"boolean\"}\n",
|
154 |
-
"save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
|
155 |
-
"save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
|
156 |
"clip_fix = \"fix err key\" #@param [\"off\", \"fix err key\", \"del err key\"]\n",
|
157 |
"uninvited_key = \"cond_stage_model.transformer.text_model.embeddings.position_ids\"\n",
|
|
|
|
|
|
|
158 |
"#@markdown 変換したモデルをHugging Faceに投稿する場合は「yourname/yourrepo」の形式で投稿先リポジトリを指定<br>\n",
|
159 |
"#@markdown 投稿しない場合は何も入力しない<br>\n",
|
160 |
"# 5GB以上のファイルを投稿する場合は、投稿先リポジトリを丸ごとダウンロードする工程が挟まるので、時間がかかる場合があります\n",
|
161 |
"repo_id = \"\" #@param {type:\"string\"}\n",
|
162 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
"def upload_to_hugging_face(file_name):\n",
|
164 |
" api = HfApi()\n",
|
165 |
" api.upload_file(path_or_fileobj=file_name,\n",
|
@@ -183,6 +193,12 @@
|
|
183 |
" upload_to_hugging_face(file_name)\n",
|
184 |
" os.chdir(\"/content\")\n",
|
185 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
186 |
"if models == \"\":\n",
|
187 |
" models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
|
188 |
"else:\n",
|
@@ -195,15 +211,20 @@
|
|
195 |
" elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
|
196 |
" print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
|
197 |
" else:\n",
|
198 |
-
" load_model = partial(safetensors.torch.load_file, device=\"cpu\") if
|
199 |
" save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
|
200 |
" # convert model\n",
|
201 |
" with torch.no_grad():\n",
|
202 |
" weights = load_model(model)\n",
|
203 |
" if \"state_dict\" in weights:\n",
|
204 |
" weights = weights[\"state_dict\"]\n",
|
|
|
|
|
|
|
|
|
|
|
205 |
" if as_fp16:\n",
|
206 |
-
" model_name
|
207 |
" for key in weights.keys():\n",
|
208 |
" weights[key] = weights[key].half()\n",
|
209 |
" if uninvited_key in weights:\n",
|
@@ -211,6 +232,14 @@
|
|
211 |
" del weights[uninvited_key]\n",
|
212 |
" if clip_fix == \"fix err key\":\n",
|
213 |
" weights[uninvited_key] = torch.tensor([list(range(77))],dtype=torch.int64)\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
" if save_directly_to_Google_Drive:\n",
|
215 |
" os.chdir(\"/content/drive/MyDrive\")\n",
|
216 |
" save_model(weights, saved_model := model_name + save_type)\n",
|
@@ -226,8 +255,8 @@
|
|
226 |
"!reset"
|
227 |
],
|
228 |
"metadata": {
|
229 |
-
"
|
230 |
-
"
|
231 |
},
|
232 |
"execution_count": null,
|
233 |
"outputs": []
|
@@ -254,6 +283,15 @@
|
|
254 |
"id": "0SUK6Alv2ItS"
|
255 |
}
|
256 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
{
|
258 |
"cell_type": "markdown",
|
259 |
"source": [
|
|
|
121 |
"for model in models:\n",
|
122 |
" if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
|
123 |
" wget.download(model)\n",
|
124 |
+
" elif model.endswith((\".ckpt\", \".safetensors\", \".yaml\", \".pt\")):\n",
|
125 |
" shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n",
|
126 |
" else:\n",
|
127 |
" print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"
|
|
|
150 |
" from huggingface_hub import HfApi, Repository\n",
|
151 |
"\n",
|
152 |
"models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
|
153 |
+
"pruning = True #@param {type:\"boolean\"}\n",
|
154 |
"as_fp16 = True #@param {type:\"boolean\"}\n",
|
|
|
|
|
155 |
"clip_fix = \"fix err key\" #@param [\"off\", \"fix err key\", \"del err key\"]\n",
|
156 |
"uninvited_key = \"cond_stage_model.transformer.text_model.embeddings.position_ids\"\n",
|
157 |
+
"save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
|
158 |
+
"merge_vae = \"\" #@param [\"\", \"vae-ft-mse-840000-ema-pruned.ckpt\", \"kl-f8-anime.ckpt\", \"kl-f8-anime2.ckpt\", \"Anything-V3.0.vae.pt\"] {allow-input: true}\n",
|
159 |
+
"save_directly_to_Google_Drive = False #@param {type:\"boolean\"}\n",
|
160 |
"#@markdown 変換したモデルをHugging Faceに投稿する場合は「yourname/yourrepo」の形式で投稿先リポジトリを指定<br>\n",
|
161 |
"#@markdown 投稿しない場合は何も入力しない<br>\n",
|
162 |
"# 5GB以上のファイルを投稿する場合は、投稿先リポジトリを丸ごとダウンロードする工程が挟まるので、時間がかかる場合があります\n",
|
163 |
"repo_id = \"\" #@param {type:\"string\"}\n",
|
164 |
"\n",
|
165 |
+
"vae_preset = {\n",
|
166 |
+
" \"vae-ft-mse-840000-ema-pruned.ckpt\": \"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\",\n",
|
167 |
+
" \"kl-f8-anime.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\",\n",
|
168 |
+
" \"kl-f8-anime2.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\",\n",
|
169 |
+
" \"Anything-V3.0.vae.pt\": \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0.vae.pt\"}\n",
|
170 |
+
"if (merge_vae in vae_preset) & (not os.path.exists(merge_vae)):\n",
|
171 |
+
" wget.download(vae_preset[merge_vae])\n",
|
172 |
+
"\n",
|
173 |
"def upload_to_hugging_face(file_name):\n",
|
174 |
" api = HfApi()\n",
|
175 |
" api.upload_file(path_or_fileobj=file_name,\n",
|
|
|
193 |
" upload_to_hugging_face(file_name)\n",
|
194 |
" os.chdir(\"/content\")\n",
|
195 |
"\n",
|
196 |
+
"#use `str.removeprefix(p)` in python 3.9+\n",
|
197 |
+
"def remove_prefix(input_string, prefix):\n",
|
198 |
+
" if prefix and input_string.startswith(prefix):\n",
|
199 |
+
" return input_string[len(prefix):]\n",
|
200 |
+
" return input_string\n",
|
201 |
+
"\n",
|
202 |
"if models == \"\":\n",
|
203 |
" models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
|
204 |
"else:\n",
|
|
|
211 |
" elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
|
212 |
" print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
|
213 |
" else:\n",
|
214 |
+
" load_model = lambda filename: partial(safetensors.torch.load_file, device=\"cpu\")(filename) if os.path.splitext(filename)[1] == \".safetensors\" else partial(torch.load, map_location=torch.device(\"cpu\"))(filename)\n",
|
215 |
" save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
|
216 |
" # convert model\n",
|
217 |
" with torch.no_grad():\n",
|
218 |
" weights = load_model(model)\n",
|
219 |
" if \"state_dict\" in weights:\n",
|
220 |
" weights = weights[\"state_dict\"]\n",
|
221 |
+
" if pruning:\n",
|
222 |
+
" model_name += \"-pruned\"\n",
|
223 |
+
" for key in list(weights.keys()):\n",
|
224 |
+
" if key.startswith(\"model_ema.\"):\n",
|
225 |
+
" del weights[key]\n",
|
226 |
" if as_fp16:\n",
|
227 |
+
" model_name += \"-fp16\"\n",
|
228 |
" for key in weights.keys():\n",
|
229 |
" weights[key] = weights[key].half()\n",
|
230 |
" if uninvited_key in weights:\n",
|
|
|
232 |
" del weights[uninvited_key]\n",
|
233 |
" if clip_fix == \"fix err key\":\n",
|
234 |
" weights[uninvited_key] = torch.tensor([list(range(77))],dtype=torch.int64)\n",
|
235 |
+
" if merge_vae != \"\":\n",
|
236 |
+
" vae_weights = load_model(merge_vae)\n",
|
237 |
+
" if \"state_dict\" in vae_weights:\n",
|
238 |
+
" vae_weights = vae_weights[\"state_dict\"]\n",
|
239 |
+
" for key in weights.keys():\n",
|
240 |
+
" if key.startswith(\"first_stage_model.\"):\n",
|
241 |
+
" weights[key] = vae_weights[remove_prefix(key, \"first_stage_model.\")]\n",
|
242 |
+
" del vae_weights\n",
|
243 |
" if save_directly_to_Google_Drive:\n",
|
244 |
" os.chdir(\"/content/drive/MyDrive\")\n",
|
245 |
" save_model(weights, saved_model := model_name + save_type)\n",
|
|
|
255 |
"!reset"
|
256 |
],
|
257 |
"metadata": {
|
258 |
+
"cellView": "form",
|
259 |
+
"id": "QSzZqGygdXM9"
|
260 |
},
|
261 |
"execution_count": null,
|
262 |
"outputs": []
|
|
|
283 |
"id": "0SUK6Alv2ItS"
|
284 |
}
|
285 |
},
|
286 |
+
{
|
287 |
+
"cell_type": "markdown",
|
288 |
+
"source": [
|
289 |
+
"Hugging Faceに5GB以上のファイルを投稿する場合はメモリ消費量が約2倍になります"
|
290 |
+
],
|
291 |
+
"metadata": {
|
292 |
+
"id": "8KU7VgNnE0Fy"
|
293 |
+
}
|
294 |
+
},
|
295 |
{
|
296 |
"cell_type": "markdown",
|
297 |
"source": [
|