{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "### A script that converts models between the .ckpt and .safetensors formats\n", "#### The .yaml config files bundled with SD2.x-series models are converted as well\n", "#### Models can optionally be saved as fp16" ], "metadata": { "id": "fAIY_GORNEYa" } }, { "cell_type": "markdown", "source": [ "First, run the following code" ], "metadata": { "id": "OnuCk_wNLM_D" } }, { "cell_type": "code", "source": [ "!pip install torch safetensors\n", "!pip install pytorch-lightning\n", "!pip install wget" ], "metadata": { "id": "pXr7oNJzwwgU" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "If you want to read and write files on Google Drive, run the following code" ], "metadata": { "id": "NsncqZOha2e0" } }, { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount(\"/content/drive\")" ], "metadata": { "id": "liEiK8Iioscq" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title If you want to upload converted models to Hugging Face, run the following code\n", "#@markdown 1. Open [this page](https://huggingface.co./settings/tokens) and create an Access Token via New token (any Name, Role=write)\n", "#@markdown 2. Copy the token, paste it into the field below, and run this cell\n", "!pip install huggingface_hub\n", "from huggingface_hub import login\n", "token = \"\" #@param {type:\"string\"}\n", "login(token=token)" ], "metadata": { "cellView": "form", "id": "mJO8RdvIINA-" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Replace the links below with the models you want, then run both of the following code cells in order from the top" ], "metadata": { "id": "7Ils-K70k15Y" } }, { "cell_type": "code", "source": [ "#@title Download models\n", "#@markdown Specify any number of {model name on Google Drive or model download link}, separated by commas\n", "#@markdown - For a model name on Drive: specify the path relative to My Drive\n", "#@markdown - For a download link: right-click the download link on Hugging Face etc., copy the link address, and paste it in place of the links below\n", "import shutil\n", "import urllib.parse\n", "import urllib.request\n", "import wget\n", "import os\n", "\n", "models = \"Specify_the_model_in_this_way_if_the_model_is_on_My_Drive.safetensors, https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt, https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n", "models = [m.strip() for m in models.split(\",\")]\n", "for model in models:\n", "    if 0 < len(urllib.parse.urlparse(model).scheme):  # if model is a URL\n", "        wget.download(model)\n", "    elif model.endswith((\".ckpt\", \".safetensors\", \".yaml\", \".pt\")):\n", "        shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + os.path.basename(model))  # get the model from My Drive\n", "    else:\n", "        print(f\"\\\"{model}\\\" is neither a URL nor a file in a supported format\")" ], "metadata": { "id": "4vd3A09AxJE0", "cellView": "form" }, "execution_count": null, "outputs": [] },
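{ "cell_type": "markdown", "source": [ "(Optional) To sanity-check a downloaded file before converting it, the sketch below picks a loader by file extension, the same way the conversion cell does. `wd-1-4-anime_e1.ckpt` is only the default example from the cell above; replace it with whatever you actually downloaded.\n", "\n", "```python\n", "import os\n", "import torch\n", "import safetensors.torch\n", "\n", "path = \"/content/wd-1-4-anime_e1.ckpt\"  # example file name; change as needed\n", "if os.path.splitext(path)[1] == \".safetensors\":\n", "    weights = safetensors.torch.load_file(path, device=\"cpu\")\n", "else:\n", "    weights = torch.load(path, map_location=\"cpu\")\n", "    weights = weights.get(\"state_dict\", weights)  # .ckpt files usually wrap the weights in \"state_dict\"\n", "print(len(weights), \"tensors; first key:\", next(iter(weights)))\n", "```" ], "metadata": {} },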
{ "cell_type": "code", "source": [ "#@title Convert models\n", "#@markdown Specify any number of models to convert, separated by commas<br>\n", "#@markdown If nothing is entered, every loaded model is converted<br>\n", "#@markdown Models can be subtracted by writing them in the form `a.ckpt, -, b.safetensors`\n", "import os\n", "import glob\n", "import torch\n", "import wget\n", "import safetensors.torch\n", "\n", "from sys import modules\n", "if \"huggingface_hub\" in modules:\n", "    from huggingface_hub import HfApi, Repository\n", "\n", "models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n", "pruning = True #@param {type:\"boolean\"}\n", "as_fp16 = True #@param {type:\"boolean\"}\n", "clip_fix = \"fix err key\" #@param [\"off\", \"fix err key\", \"del err key\"]\n",
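"# Note: some checkpoints carry a corrupted CLIP position_ids tensor (wrong dtype or values) under the key defined below;\n", "# \"fix err key\" resets it to the standard int64 arange(77) tensor, \"del err key\" removes the key, \"off\" leaves it untouched.\n",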
"uninvited_key = \"cond_stage_model.transformer.text_model.embeddings.position_ids\"\n", "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n", "merge_vae = \"\" #@param [\"\", \"vae-ft-mse-840000-ema-pruned.ckpt\", \"kl-f8-anime.ckpt\", \"kl-f8-anime2.ckpt\", \"anything-v4.0.vae.pt\"] {allow-input: true}\n",
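"# merge_vae bakes the selected VAE into the output: every first_stage_model.* tensor is overwritten with the VAE's weights before saving (see below).\n",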
"save_directly_to_Google_Drive = False #@param {type:\"boolean\"}\n", "#@markdown To upload converted models to Hugging Face, specify the destination repository in the form \"yourname/yourrepo\"<br>\n", "#@markdown Leave this empty if you do not want to upload<br>\n", "# Uploading a file of 5 GB or larger involves cloning the whole destination repository first, so it can take a while\n", "repo_id = \"\" #@param {type:\"string\"}\n", "\n", "vae_preset = {\n", "    \"vae-ft-mse-840000-ema-pruned.ckpt\": \"https://huggingface.co./stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt\",\n", "    \"kl-f8-anime.ckpt\": \"https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime.ckpt\",\n", "    \"kl-f8-anime2.ckpt\": \"https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/vae/kl-f8-anime2.ckpt\",\n", "    \"anything-v4.0.vae.pt\": \"https://huggingface.co./andite/anything-v4.0/resolve/main/anything-v4.0.vae.pt\"}\n", "if (merge_vae in vae_preset) and (not os.path.exists(merge_vae)):\n", "    wget.download(vae_preset[merge_vae])\n", "\n", "def upload_to_hugging_face(file_name):\n", "    api = HfApi()\n", "    api.upload_file(path_or_fileobj=file_name,\n", "                    path_in_repo=file_name,\n", "                    repo_id=repo_id,\n", "                    )\n", "\n", "def convert_yaml(file_name):\n", "    with open(file_name) as f:\n", "        yaml = f.read()\n", "    if save_directly_to_Google_Drive:\n", "        os.chdir(\"/content/drive/MyDrive\")\n", "    is_safe = save_type == \".safetensors\"\n", "    yaml = yaml.replace(f\"use_checkpoint: {is_safe}\", f\"use_checkpoint: {not is_safe}\")\n", "    if as_fp16:\n", "        yaml = yaml.replace(\"use_fp16: False\", \"use_fp16: True\")\n", "        file_name = os.path.splitext(file_name)[0] + \"-fp16.yaml\"\n", "    with open(file_name, mode=\"w\") as f:\n", "        f.write(yaml)\n", "    if repo_id != \"\":\n", "        upload_to_hugging_face(file_name)\n", "    os.chdir(\"/content\")\n", "\n", "# use `str.removeprefix(p)` in Python 3.9+\n", "def remove_prefix(input_string, prefix):\n", "    if prefix and input_string.startswith(prefix):\n", "        return input_string[len(prefix):]\n", "    return input_string\n", "\n", "load_model = lambda m: safetensors.torch.load_file(m, device=\"cpu\") if os.path.splitext(m)[1] == \".safetensors\" else torch.load(m, map_location=torch.device(\"cpu\"))\n", "save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n", "\n", "# --- def merge ---#\n", "@torch.no_grad()\n", "def merged(model_a, model_b, fab, fa, fb):\n", "    weights_a = load_model(model_a) if isinstance(model_a, str) else model_a\n", "    weights_b = load_model(model_b) if isinstance(model_b, str) else model_b\n", "    if \"state_dict\" in weights_a:\n", "        weights_a = weights_a[\"state_dict\"]\n", "    if \"state_dict\" in weights_b:\n", "        weights_b = weights_b[\"state_dict\"]\n", "    for key in list(set(weights_a.keys()) | set(weights_b.keys())):\n", "        if key in weights_a and isinstance(weights_a[key], dict):\n", "            del weights_a[key]\n", "        if key in weights_b and isinstance(weights_b[key], dict):\n", "            del weights_b[key]\n", "        if key.startswith(\"model.\") or key.startswith(\"model_ema.\"):\n", "            if (key in weights_a) and (key in weights_b):\n", "                weights_a[key] = fab(weights_a[key], weights_b[key])\n", "                del weights_b[key]\n", "            elif key in weights_a:\n", "                weights_a[key] = fa(weights_a[key])\n", "            elif key in weights_b:\n", "                weights_a[key] = fb(weights_b[key])\n", "                del weights_b[key]\n", "    del weights_b\n", "    return weights_a\n", "\n", "def add(model_a, model_b):\n", "    return merged(model_a, model_b, lambda a, b: a + b, lambda a: a, lambda b: b)\n", "\n", "def difference(model_a, model_b):\n", "    return merged(model_a, model_b, lambda a, b: a - b, lambda a: a, lambda b: -b)\n", "\n", "def add_difference(model_a, model_b, model_c):\n", "    return add(model_a, difference(model_b, model_c))\n", "# --- end merge ---#\n",
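"# Usage sketch (comments only): entering \"a.ckpt, -, b.ckpt\" in the models field makes the loop below call difference(\"a.ckpt\", \"b.ckpt\"),\n", "# so every model.* / model_ema.* tensor present in both files becomes a - b and the result is saved as \"a-b\" plus the usual suffixes.\n", "# add(a, b) and add_difference(a, b, c) (= a + (b - c)) are defined above as well, but are not reachable from the form field.\n",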
"\n", "if models == \"\":\n", "    models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\") if not os.path.basename(m) in vae_preset]\n", "else:\n", "    models = [m.strip() for m in models.split(\",\")]\n", "\n", "for i, model in enumerate(models):\n", "    model_name, model_ext = os.path.splitext(model)\n", "    # for \"a.ckpt, -, b.ckpt\", skip the \"-\" and the subtrahend b.ckpt here\n", "    if (models[i] == \"-\") or (models[i - 1] == \"-\"):\n", "        continue\n", "    if model_ext == \".yaml\":\n", "        convert_yaml(model)\n", "    elif (model_ext != \".safetensors\") and (model_ext != \".ckpt\"):\n", "        print(\"Only .ckpt, .safetensors and .yaml files are supported\\n\" + f\"\\\"{model}\\\" is not in a supported format\")\n", "    else:\n", "        # convert the model\n", "        with torch.no_grad():\n", "            # \"a.ckpt, -, b.ckpt\" -> this is a.ckpt, so compute the difference a - b\n", "            if (i < len(models) - 2) and (models[i + 1] == \"-\"):\n", "                weights = difference(model, models[i + 2])\n", "                model_name = f\"{model_name}-{os.path.splitext(models[i + 2])[0]}\"\n", "            # otherwise just load the model\n", "            else:\n", "                weights = load_model(model)\n", "                if \"state_dict\" in weights:\n", "                    weights = weights[\"state_dict\"]\n", "                for key in list(weights.keys()):\n", "                    if isinstance(weights[key], dict):\n", "                        del weights[key]  # to fix broken models\n", "            if pruning:\n", "                model_name += \"-pruned\"\n", "                for key in list(weights.keys()):\n", "                    if key.startswith(\"model_ema.\"):\n", "                        del weights[key]\n", "            if as_fp16:\n", "                model_name += \"-fp16\"\n", "                for key in weights.keys():\n", "                    weights[key] = weights[key].half()\n", "            if uninvited_key in weights:\n", "                if clip_fix == \"del err key\":\n", "                    del weights[uninvited_key]\n", "                if clip_fix == \"fix err key\":\n", "                    weights[uninvited_key] = torch.tensor([list(range(77))], dtype=torch.int64)\n", "            if merge_vae != \"\":\n", "                vae_weights = load_model(merge_vae)\n", "                if \"state_dict\" in vae_weights:\n", "                    vae_weights = vae_weights[\"state_dict\"]\n", "                for key in weights.keys():\n", "                    if key.startswith(\"first_stage_model.\"):\n", "                        weights[key] = vae_weights[remove_prefix(key, \"first_stage_model.\")]\n", "                del vae_weights\n", "            if save_directly_to_Google_Drive:\n", "                os.chdir(\"/content/drive/MyDrive\")\n", "            save_model(weights, saved_model := model_name + save_type)\n", "            if repo_id != \"\":\n", "                if os.path.getsize(saved_model) >= 5*1000*1000*1000:\n", "                    with Repository(os.path.basename(repo_id), clone_from=repo_id, skip_lfs_files=True, token=True).commit(commit_message=f\"Upload {saved_model} with huggingface_hub\", blocking=False):\n", "                        save_model(weights, saved_model)\n", "                else:\n", "                    upload_to_hugging_face(saved_model)\n", "            os.chdir(\"/content\")\n", "            del weights\n", "\n", "!reset" ], "metadata": { "cellView": "form", "id": "QSzZqGygdXM9" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "When converting an SD2.x-series model or similar, download/convert its bundled config file (the .yaml file with the same name as the model) along with it\n", "\n", "It is specified in the same way as a model" ], "metadata": { "id": "SWTFKmGFLec6" } }, { "cell_type": "markdown", "source": [ "If the runtime crashes because it runs out of memory, use a smaller model or switch to a paid high-memory runtime\n", "\n", "By default, models of up to 10 GB can be converted" ], "metadata": { "id": "0SUK6Alv2ItS" } }, { "cell_type": "markdown", "source": [ "Pick whatever model you like, for example from this [list of models](https://huggingface.co./models?other=stable-diffusion)" ], "metadata": { "id": "yaLq5Nqe6an6" } } ] }