Update as_safetensors+fp16.ipynb

as_safetensors+fp16.ipynb CHANGED (+60 -20)
@@ -130,12 +130,12 @@
 "source": [
 "#@title <font size=\"-0\">モデルを変換</font>\n",
 "#@markdown 変換するモデルをカンマ区切りで任意個指定<br>\n",
-"#@markdown …
+"#@markdown 何も入力されていない場合は、読み込まれている全てのモデルが変換される<br>\n",
+"#@markdown `a.ckpt, -, b.safetensors`のような形式でモデルの引き算ができます\n",
 "import os\n",
 "import glob\n",
 "import torch\n",
 "import safetensors.torch\n",
-"from functools import partial\n",
 "\n",
 "from sys import modules\n",
 "if \"huggingface_hub\" in modules:\n",
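In plain English, the two new `#@markdown` lines say: if the field is left empty, every loaded model is converted, and models can be subtracted with input of the form `a.ckpt, -, b.safetensors`. The `functools.partial` import is dropped because the `load_model` helper is rewritten without `partial` later in this diff.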
@@ -147,7 +147,7 @@
 "clip_fix = \"fix err key\" #@param [\"off\", \"fix err key\", \"del err key\"]\n",
 "uninvited_key = \"cond_stage_model.transformer.text_model.embeddings.position_ids\"\n",
 "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
-"merge_vae = \"\" #@param [\"\", \"vae-ft-mse-840000-ema-pruned.ckpt\", \"kl-f8-anime.ckpt\", \"kl-f8-anime2.ckpt\", \"…
+"merge_vae = \"\" #@param [\"\", \"vae-ft-mse-840000-ema-pruned.ckpt\", \"kl-f8-anime.ckpt\", \"kl-f8-anime2.ckpt\", \"anything-v4.0.vae.pt\"] {allow-input: true}\n",
 "save_directly_to_Google_Drive = False #@param {type:\"boolean\"}\n",
 "#@markdown 変換したモデルをHugging Faceに投稿する場合は「yourname/yourrepo」の形式で投稿先リポジトリを指定<br>\n",
 "#@markdown 投稿しない場合は何も入力しない<br>\n",
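The `merge_vae` dropdown gains an `anything-v4.0.vae.pt` preset and is now declared with `{allow-input: true}`, which lets a Colab user type a VAE filename that is not in the preset list.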
@@ -158,8 +158,8 @@
 "    \"vae-ft-mse-840000-ema-pruned.ckpt\": \"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\",\n",
 "    \"kl-f8-anime.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\",\n",
 "    \"kl-f8-anime2.ckpt\": \"https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\",\n",
-"    \"…
-"if (merge_vae in vae_preset) …
+"    \"anything-v4.0.vae.pt\": \"https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0.vae.pt\"}\n",
+"if (merge_vae in vae_preset) and (not os.path.exists(merge_vae)):\n",
 "    wget.download(vae_preset[merge_vae])\n",
 "\n",
 "def upload_to_hugging_face(file_name):\n",
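The matching preset URL is added, and the download guard now also checks `not os.path.exists(merge_vae)`, so re-running the cell does not download a preset VAE that is already on disk.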
@@ -191,25 +191,74 @@
 "        return input_string[len(prefix):]\n",
 "    return input_string\n",
 "\n",
+"load_model = lambda m: safetensors.torch.load_file(m, device=\"cpu\") if os.path.splitext(m)[1] == \".safetensors\" else torch.load(m, map_location=torch.device(\"cpu\"))\n",
+"save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
+"\n",
+"# --- def merge ---#\n",
+"@torch.no_grad()\n",
+"def merged(model_a, model_b, fab, fa, fb):\n",
+"    weights_a = load_model(model_a) if isinstance(model_a, str) else model_a\n",
+"    weights_b = load_model(model_b) if isinstance(model_b, str) else model_b\n",
+"    if \"state_dict\" in weights_a:\n",
+"        weights_a = weights_a[\"state_dict\"]\n",
+"    if \"state_dict\" in weights_b:\n",
+"        weights_b = weights_b[\"state_dict\"]\n",
+"    for key in list(weights_a.keys() or weights_b.keys()):\n",
+"        if isinstance(weights_a[key], dict):\n",
+"            del weights_a[key]\n",
+"        if isinstance(weights_b[key], dict):\n",
+"            del weights_b[key]\n",
+"        if key.startswith(\"model.\") or key.startswith(\"model_ema.\"):\n",
+"            if (key in weights_a) and (key in weights_b):\n",
+"                weights_a[key] = fab(weights_a[key], weights_b[key])\n",
+"                del weights_b[key]\n",
+"            elif key in weights_a:\n",
+"                weights_a[key] = fa(weights_a[key])\n",
+"            elif key in weights_b:\n",
+"                weights_a[key] = fb(weights_b[key])\n",
+"                del weights_b[key]\n",
+"    del weights_b\n",
+"    return weights_a\n",
+"\n",
+"def add(model_a, model_b):\n",
+"    return merged(model_a, model_b, lambda a, b: a + b, lambda a: a, lambda b: b)\n",
+"\n",
+"def difference(model_a, model_b):\n",
+"    return merged(model_a, model_b, lambda a, b: a - b, lambda a: a, lambda b: -b)\n",
+"\n",
+"def add_difference(model_a, model_b, model_c):\n",
+"    return add(model_a, difference(model_b, model_c))\n",
+"# --- end merge ---#\n",
+"\n",
 "if models == \"\":\n",
-"    models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
+"    models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\") if not os.path.basename(m) in vae_preset]\n",
 "else:\n",
 "    models = [m.strip() for m in models.split(\",\")]\n",
 "\n",
-"for model in models:\n",
+"for i, model in enumerate(models):\n",
 "    model_name, model_ext = os.path.splitext(model)\n",
+"    # a.ckpt, - ,b.ckpt # - or b.ckpt\n",
+"    if (models[i] == \"-\") or (models[i - 1] == \"-\"):\n",
+"        continue\n",
 "    if model_ext == \".yaml\":\n",
 "        convert_yaml(model)\n",
-"    elif (model_ext != \".safetensors\") …
+"    elif (model_ext != \".safetensors\") and (model_ext != \".ckpt\"):\n",
 "        print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
 "    else:\n",
-"        load_model = lambda filename: partial(safetensors.torch.load_file, device=\"cpu\")(filename) if os.path.splitext(filename)[1] == \".safetensors\" else partial(torch.load, map_location=torch.device(\"cpu\"))(filename)\n",
-"        save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
 "        # convert model\n",
 "        with torch.no_grad():\n",
-"…
+"            # a.ckpt, - ,b.ckpt # a.ckpt\n",
+"            if (i < len(models) - 1) and (models[i + 1] == \"-\"):\n",
+"                weights = difference(model, models[i + 2])\n",
+"                model_name = f\"{model_name}-{os.path.splitext(models[i + 2])[0]}\"\n",
+"            # otherwise\n",
+"            else:\n",
+"                weights = load_model(model)\n",
 "            if \"state_dict\" in weights:\n",
 "                weights = weights[\"state_dict\"]\n",
+"            for key in list(weights.keys()):\n",
+"                if isinstance(weights[key], dict):\n",
+"                    del weights[key] # to fix the broken model\n",
 "            if pruning:\n",
 "                model_name += \"-pruned\"\n",
 "                for key in list(weights.keys()):\n",
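Since this hunk is the heart of the change, here is a minimal standalone sketch of the merge helpers, lifted out of the notebook JSON; `load_weights` is a hypothetical stand-in for the notebook's `load_model` lambda. One deliberate deviation: as rendered above, `merged` iterates `list(weights_a.keys() or weights_b.keys())`, and since a non-empty `dict_keys` is truthy that expression only ever yields model A's keys while the body still indexes `weights_b[key]` unconditionally. The sketch instead iterates the explicit union of both key sets and guards every lookup, which also makes the `fb` branch reachable for keys unique to model B.

```python
import torch
import safetensors.torch

def load_weights(path):
    # hypothetical stand-in for the notebook's load_model lambda:
    # load either checkpoint format onto the CPU
    if path.endswith(".safetensors"):
        return safetensors.torch.load_file(path, device="cpu")
    return torch.load(path, map_location=torch.device("cpu"))

@torch.no_grad()
def merged(model_a, model_b, fab, fa, fb):
    # arguments may be file paths or already-loaded state dicts
    wa = load_weights(model_a) if isinstance(model_a, str) else model_a
    wb = load_weights(model_b) if isinstance(model_b, str) else model_b
    wa = wa.get("state_dict", wa)
    wb = wb.get("state_dict", wb)
    for key in list(wa.keys() | wb.keys()):  # explicit union of both key sets
        # drop nested-dict junk entries, checking membership before indexing
        if key in wa and isinstance(wa[key], dict):
            del wa[key]
            continue
        if key in wb and isinstance(wb[key], dict):
            del wb[key]
            continue
        if key.startswith(("model.", "model_ema.")):
            if key in wa and key in wb:
                wa[key] = fab(wa[key], wb[key])  # key present in both
            elif key in wa:
                wa[key] = fa(wa[key])            # key only in A
            else:
                wa[key] = fb(wb[key])            # key only in B
    return wa

def add(model_a, model_b):
    return merged(model_a, model_b, lambda x, y: x + y, lambda x: x, lambda y: y)

def difference(model_a, model_b):
    return merged(model_a, model_b, lambda x, y: x - y, lambda x: x, lambda y: -y)

def add_difference(model_a, model_b, model_c):
    # A + (B - C), the usual "add difference" recipe
    return add(model_a, difference(model_b, model_c))
```

Note that `difference` negates keys that exist only in the subtrahend (`fb = lambda y: -y`), so subtracting a model that carries extra `model.`/`model_ema.` keys still produces a well-defined result.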
@@ -275,15 +324,6 @@
   "id": "0SUK6Alv2ItS"
  }
 },
-{
- "cell_type": "markdown",
- "source": [
-  "Hugging Faceに5GB以上のファイルを投稿する場合はメモリ消費量が約2倍になります"
- ],
- "metadata": {
-  "id": "8KU7VgNnE0Fy"
- }
-},
 {
  "cell_type": "markdown",
  "source": [
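The deleted markdown cell carried the note "Hugging Faceに5GB以上のファイルを投稿する場合はメモリ消費量が約2倍になります", i.e. uploading a file of 5 GB or more to Hugging Face roughly doubles memory consumption.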
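Finally, a sketch of how the rewritten conversion loop interprets the subtraction syntax. The three-element `models` list is a hypothetical example of what `a.ckpt, -, b.safetensors` splits into, and `difference`/`load_weights` come from the sketch above:

```python
import os

# hypothetical parsed input for "a.ckpt, -, b.safetensors"
models = ["a.ckpt", "-", "b.safetensors"]

for i, model in enumerate(models):
    # skip the "-" token itself and the subtrahend that follows it
    if models[i] == "-" or models[i - 1] == "-":
        continue
    model_name, model_ext = os.path.splitext(model)
    if i < len(models) - 1 and models[i + 1] == "-":
        # subtraction form: compute A - B and name the output "a-b"
        weights = difference(model, models[i + 2])
        model_name = f"{model_name}-{os.path.splitext(models[i + 2])[0]}"
    else:
        weights = load_weights(model)
```

One quirk worth knowing: for `i == 0`, `models[i - 1]` is `models[-1]`, the last element, so an input ending in a stray `-` would silently skip the first model.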