File size: 7,611 Bytes
afe0452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "### モデルの形式 (.ckpt/.safetensors) を相互変換するスクリプトです\n",
        "#### SD2.x系付属の.yamlも併せて変換します\n",
        "#### オプションでfp16として保存できます"
      ],
      "metadata": {
        "id": "fAIY_GORNEYa"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "以下のコードを上から順番に両方とも実行"
      ],
      "metadata": {
        "id": "OnuCk_wNLM_D"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from google.colab import drive \n",
        "drive.mount(\"/content/drive\")"
      ],
      "metadata": {
        "id": "liEiK8Iioscq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "!pip install torch safetensors\n",
        "!pip install pytorch-lightning\n",
        "!pip install wget"
      ],
      "metadata": {
        "id": "pXr7oNJzwwgU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "以下のリンク等を任意のものに差し替えてから、以下のコードを上から順番に両方とも実行"
      ],
      "metadata": {
        "id": "7Ils-K70k15Y"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title <font size=\"-0\">モデルをダウンロード</font>\n",
        "#@markdown {Google Drive上のモデル名 or モデルのダウンロードリンク} をカンマ区切りで任意個指定\n",
        "#@markdown - Drive上のモデル名の場合...My Driveに対する相対パスで指定\n",
        "#@markdown - ダウンロードリンクの場合...Hugging Face等のダウンロードリンクを右クリック & リンクのアドレスをコピー & 下のリンクの代わりに貼り付け\n",
        "import shutil\n",
        "import urllib.parse\n",
        "import urllib.request\n",
        "import wget\n",
        "\n",
        "models = \"Specify_the_model_in_this_way_if_the_model_is_on_My_Drive.safetensors, https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt, https://huggingface.co./hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
        "models = [m.strip() for m in models.split(\",\")]\n",
        "for model in models:\n",
        "  if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
        "    wget.download(model)\n",
        "  elif model.endswith((\".ckpt\", \".safetensors\")):\n",
        "    shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n",
        "  else:\n",
        "    print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"
      ],
      "metadata": {
        "cellView": "form",
        "id": "4vd3A09AxJE0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title <font size=\"-0\">モデルを変換</font>\n",
        "#@markdown 変換するモデルをカンマ区切りで任意個指定<br>\n",
        "#@markdown 何も入力されていない場合は、読み込まれている全てのモデルが変換される\n",
        "import os\n",
        "import glob\n",
        "import torch\n",
        "import safetensors.torch\n",
        "\n",
        "models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
        "as_fp16 = True #@param {type:\"boolean\"}\n",
        "save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
        "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
        "\n",
        "def convert_yaml(file_name):\n",
        "  with open(file_name) as f:\n",
        "    yaml = f.read()\n",
        "  if save_directly_to_Google_Drive:\n",
        "    os.chdir(\"/content/drive/MyDrive\")\n",
        "  is_safe = save_type == \".safetensors\"\n",
        "  yaml = yaml.replace(f\"use_checkpoint: {is_safe}\", f\"use_checkpoint: {not is_safe}\")\n",
        "  if as_fp16:\n",
        "    yaml = yaml.replace(\"use_fp16: False\", \"use_fp16: True\")\n",
        "    file_name = os.path.splitext(file_name)[0] + \"-fp16.yaml\"\n",
        "  with open(file_name, mode=\"w\") as f:\n",
        "    f.write(yaml)\n",
        "  os.chdir(\"/content\")\n",
        "\n",
        "if models == \"\":\n",
        "  models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
        "else:\n",
        "  models = [m.strip() for m in models.split(\",\")]\n",
        "\n",
        "for model in models:\n",
        "  model_name, model_ext = os.path.splitext(model)\n",
        "  if model_ext == \".yaml\":\n",
        "    convert_yaml(model)\n",
        "  elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
        "    print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
        "  else:\n",
        "    load_model = safetensors.torch.load_file if model_ext == \".safetensors\" else torch.load\n",
        "    save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
        "    # convert model\n",
        "    with torch.no_grad():\n",
        "      weights = load_model(model)\n",
        "      if \"state_dict\" in weights:\n",
        "        weights = weights[\"state_dict\"]\n",
        "      if as_fp16:\n",
        "        model_name = model_name + \"-fp16\"\n",
        "        for key in weights.keys():\n",
        "          weights[key] = weights[key].half()\n",
        "      if save_directly_to_Google_Drive:\n",
        "        os.chdir(\"/content/drive/MyDrive\")\n",
        "      save_model(weights, model_name + save_type)\n",
        "      os.chdir(\"/content\")\n",
        "      del weights\n",
        "\n",
        "!reset"
      ],
      "metadata": {
        "id": "9OmSG98HxJg2",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "SD2.x系モデル等を変換する場合は、付属の設定ファイル (モデルと同名の.yamlファイル) も同時にダウンロード/変換しましょう\n",
        "\n",
        "指定方法はモデルと同じです"
      ],
      "metadata": {
        "id": "SWTFKmGFLec6"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "メモリ不足でクラッシュする場合は、より小さいモデルを利用するか、有料のハイメモリランタイムを使用すること\n",
        "\n",
        "標準では10GBまでのモデルを変換できます"
      ],
      "metadata": {
        "id": "0SUK6Alv2ItS"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "モデルのリンク集: https://huggingface.co./models?other=stable-diffusion 等から好きなモデルを選ぼう"
      ],
      "metadata": {
        "id": "yaLq5Nqe6an6"
      }
    }
  ]
}