File size: 7,735 Bytes
afe0452 0632dfe afe0452 b678a86 afe0452 b678a86 afe0452 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"### モデルの形式 (.ckpt/.safetensors) を相互変換するスクリプトです\n",
"#### SD2.x系付属の.yamlも併せて変換します\n",
"#### オプションでfp16として保存できます"
],
"metadata": {
"id": "fAIY_GORNEYa"
}
},
{
"cell_type": "markdown",
"source": [
"以下のコードを上から順番に両方とも実行"
],
"metadata": {
"id": "OnuCk_wNLM_D"
}
},
{
"cell_type": "code",
"source": [
"from google.colab import drive \n",
"drive.mount(\"/content/drive\")"
],
"metadata": {
"id": "liEiK8Iioscq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!pip install torch safetensors\n",
"!pip install pytorch-lightning\n",
"!pip install wget"
],
"metadata": {
"id": "pXr7oNJzwwgU"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"以下のリンク等を任意のものに差し替えてから、以下のコードを上から順番に両方とも実行"
],
"metadata": {
"id": "7Ils-K70k15Y"
}
},
{
"cell_type": "code",
"source": [
"#@title <font size=\"-0\">モデルをダウンロード</font>\n",
"#@markdown {Google Drive上のモデル名 or モデルのダウンロードリンク} をカンマ区切りで任意個指定\n",
"#@markdown - Drive上のモデル名の場合...My Driveに対する相対パスで指定\n",
"#@markdown - ダウンロードリンクの場合...Hugging Face等のダウンロードリンクを右クリック & リンクのアドレスをコピー & 下のリンクの代わりに貼り付け\n",
"import shutil\n",
"import urllib.parse\n",
"import urllib.request\n",
"import wget\n",
"\n",
"models = \"Specify_the_model_in_this_way_if_the_model_is_on_My_Drive.safetensors, https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt, https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
"models = [m.strip() for m in models.split(\",\")]\n",
"for model in models:\n",
" if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
" wget.download(model)\n",
" elif model.endswith((\".ckpt\", \".safetensors\", \".yaml\")):\n",
" shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n",
" else:\n",
" print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"
],
"metadata": {
"cellView": "form",
"id": "4vd3A09AxJE0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title <font size=\"-0\">モデルを変換</font>\n",
"#@markdown 変換するモデルをカンマ区切りで任意個指定<br>\n",
"#@markdown 何も入力されていない場合は、読み込まれている全てのモデルが変換される\n",
"import os\n",
"import glob\n",
"import torch\n",
"import safetensors.torch\n",
"from functools import partial\n",
"\n",
"models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
"as_fp16 = True #@param {type:\"boolean\"}\n",
"save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
"save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
"\n",
"def convert_yaml(file_name):\n",
" with open(file_name) as f:\n",
" yaml = f.read()\n",
" if save_directly_to_Google_Drive:\n",
" os.chdir(\"/content/drive/MyDrive\")\n",
" is_safe = save_type == \".safetensors\"\n",
" yaml = yaml.replace(f\"use_checkpoint: {is_safe}\", f\"use_checkpoint: {not is_safe}\")\n",
" if as_fp16:\n",
" yaml = yaml.replace(\"use_fp16: False\", \"use_fp16: True\")\n",
" file_name = os.path.splitext(file_name)[0] + \"-fp16.yaml\"\n",
" with open(file_name, mode=\"w\") as f:\n",
" f.write(yaml)\n",
" os.chdir(\"/content\")\n",
"\n",
"if models == \"\":\n",
" models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
"else:\n",
" models = [m.strip() for m in models.split(\",\")]\n",
"\n",
"for model in models:\n",
" model_name, model_ext = os.path.splitext(model)\n",
" if model_ext == \".yaml\":\n",
" convert_yaml(model)\n",
" elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
" print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
" else:\n",
" load_model = partial(safetensors.torch.load_file, device=\"cpu\") if model_ext == \".safetensors\" else partial(torch.load, map_location=torch.device(\"cpu\"))\n",
" save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
" # convert model\n",
" with torch.no_grad():\n",
" weights = load_model(model)\n",
" if \"state_dict\" in weights:\n",
" weights = weights[\"state_dict\"]\n",
" if as_fp16:\n",
" model_name = model_name + \"-fp16\"\n",
" for key in weights.keys():\n",
" weights[key] = weights[key].half()\n",
" if save_directly_to_Google_Drive:\n",
" os.chdir(\"/content/drive/MyDrive\")\n",
" save_model(weights, model_name + save_type)\n",
" os.chdir(\"/content\")\n",
" del weights\n",
"\n",
"!reset"
],
"metadata": {
"id": "9OmSG98HxJg2",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"SD2.x系モデル等を変換する場合は、付属の設定ファイル (モデルと同名の.yamlファイル) も同時にダウンロード/変換しましょう\n",
"\n",
"指定方法はモデルと同じです"
],
"metadata": {
"id": "SWTFKmGFLec6"
}
},
{
"cell_type": "markdown",
"source": [
"メモリ不足でクラッシュする場合は、より小さいモデルを利用するか、有料のハイメモリランタイムを使用すること\n",
"\n",
"標準では10GBまでのモデルを変換できます"
],
"metadata": {
"id": "0SUK6Alv2ItS"
}
},
{
"cell_type": "markdown",
"source": [
"モデルのリンク集: https://huggingface.co./models?other=stable-diffusion 等から好きなモデルを選ぼう"
],
"metadata": {
"id": "yaLq5Nqe6an6"
}
}
]
} |