subaqua commited on
Commit
afe0452
1 Parent(s): 2c89425

Upload as_safetensors+fp16_reversible.ipynb

Browse files
Files changed (1) hide show
  1. as_safetensors+fp16_reversible.ipynb +200 -0
as_safetensors+fp16_reversible.ipynb ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": []
7
+ },
8
+ "kernelspec": {
9
+ "name": "python3",
10
+ "display_name": "Python 3"
11
+ },
12
+ "language_info": {
13
+ "name": "python"
14
+ }
15
+ },
16
+ "cells": [
17
+ {
18
+ "cell_type": "markdown",
19
+ "source": [
20
+ "### モデルの形式 (.ckpt/.safetensors) を相互変換するスクリプトです\n",
21
+ "#### SD2.x系付属の.yamlも併せて変換します\n",
22
+ "#### オプションでfp16として保存できます"
23
+ ],
24
+ "metadata": {
25
+ "id": "fAIY_GORNEYa"
26
+ }
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "source": [
31
+ "以下のコードを上から順番に両方とも実行"
32
+ ],
33
+ "metadata": {
34
+ "id": "OnuCk_wNLM_D"
35
+ }
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "source": [
40
+ "from google.colab import drive \n",
41
+ "drive.mount(\"/content/drive\")"
42
+ ],
43
+ "metadata": {
44
+ "id": "liEiK8Iioscq"
45
+ },
46
+ "execution_count": null,
47
+ "outputs": []
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "source": [
52
+ "!pip install torch safetensors\n",
53
+ "!pip install pytorch-lightning\n",
54
+ "!pip install wget"
55
+ ],
56
+ "metadata": {
57
+ "id": "pXr7oNJzwwgU"
58
+ },
59
+ "execution_count": null,
60
+ "outputs": []
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "source": [
65
+ "以下のリンク等を任意のものに差し替えてから、以下のコードを上から順番に両方とも実行"
66
+ ],
67
+ "metadata": {
68
+ "id": "7Ils-K70k15Y"
69
+ }
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "source": [
74
+ "#@title <font size=\"-0\">モデルをダウンロード</font>\n",
75
+ "#@markdown {Google Drive上のモデル名 or モデルのダウンロードリンク} をカンマ区切りで任意個指定\n",
76
+ "#@markdown - Drive上のモデル名の場合...My Driveに対する相対パスで指定\n",
77
+ "#@markdown - ダウンロードリンクの場合...Hugging Face等のダウンロードリンクを右クリック & リンクのアドレスをコピー & 下のリンクの代わりに貼り付け\n",
78
+ "import shutil\n",
79
+ "import urllib.parse\n",
80
+ "import urllib.request\n",
81
+ "import wget\n",
82
+ "\n",
83
+ "models = \"Specify_the_model_in_this_way_if_the_model_is_on_My_Drive.safetensors, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.ckpt, https://huggingface.co/hakurei/waifu-diffusion-v1-4/resolve/main/wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
84
+ "models = [m.strip() for m in models.split(\",\")]\n",
85
+ "for model in models:\n",
86
+ " if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
87
+ " wget.download(model)\n",
88
+ " elif model.endswith((\".ckpt\", \".safetensors\")):\n",
89
+ " shutil.copy(\"/content/drive/MyDrive/\" + model, \"/content/\" + model) # get the model from mydrive\n",
90
+ " else:\n",
91
+ " print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"
92
+ ],
93
+ "metadata": {
94
+ "cellView": "form",
95
+ "id": "4vd3A09AxJE0"
96
+ },
97
+ "execution_count": null,
98
+ "outputs": []
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "source": [
103
+ "#@title <font size=\"-0\">モデルを変換</font>\n",
104
+ "#@markdown 変換するモデルをカンマ区切りで任意個指定<br>\n",
105
+ "#@markdown 何も入力されていない場合は、読み込まれている全てのモデルが変換される\n",
106
+ "import os\n",
107
+ "import glob\n",
108
+ "import torch\n",
109
+ "import safetensors.torch\n",
110
+ "\n",
111
+ "models = \"wd-1-4-anime_e1.ckpt, wd-1-4-anime_e1.yaml\" #@param {type:\"string\"}\n",
112
+ "as_fp16 = True #@param {type:\"boolean\"}\n",
113
+ "save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
114
+ "save_type = \".safetensors\" #@param [\".safetensors\", \".ckpt\"]\n",
115
+ "\n",
116
+ "def convert_yaml(file_name):\n",
117
+ " with open(file_name) as f:\n",
118
+ " yaml = f.read()\n",
119
+ " if save_directly_to_Google_Drive:\n",
120
+ " os.chdir(\"/content/drive/MyDrive\")\n",
121
+ " is_safe = save_type == \".safetensors\"\n",
122
+ " yaml = yaml.replace(f\"use_checkpoint: {is_safe}\", f\"use_checkpoint: {not is_safe}\")\n",
123
+ " if as_fp16:\n",
124
+ " yaml = yaml.replace(\"use_fp16: False\", \"use_fp16: True\")\n",
125
+ " file_name = os.path.splitext(file_name)[0] + \"-fp16.yaml\"\n",
126
+ " with open(file_name, mode=\"w\") as f:\n",
127
+ " f.write(yaml)\n",
128
+ " os.chdir(\"/content\")\n",
129
+ "\n",
130
+ "if models == \"\":\n",
131
+ " models = [os.path.basename(m) for m in glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.yaml\")]\n",
132
+ "else:\n",
133
+ " models = [m.strip() for m in models.split(\",\")]\n",
134
+ "\n",
135
+ "for model in models:\n",
136
+ " model_name, model_ext = os.path.splitext(model)\n",
137
+ " if model_ext == \".yaml\":\n",
138
+ " convert_yaml(model)\n",
139
+ " elif (model_ext != \".safetensors\") & (model_ext != \".ckpt\"):\n",
140
+ " print(\"対応形式は.ckpt及び.safetensors並びに.yamlのみです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
141
+ " else:\n",
142
+ " load_model = safetensors.torch.load_file if model_ext == \".safetensors\" else torch.load\n",
143
+ " save_model = safetensors.torch.save_file if save_type == \".safetensors\" else torch.save\n",
144
+ " # convert model\n",
145
+ " with torch.no_grad():\n",
146
+ " weights = load_model(model)\n",
147
+ " if \"state_dict\" in weights:\n",
148
+ " weights = weights[\"state_dict\"]\n",
149
+ " if as_fp16:\n",
150
+ " model_name = model_name + \"-fp16\"\n",
151
+ " for key in weights.keys():\n",
152
+ " weights[key] = weights[key].half()\n",
153
+ " if save_directly_to_Google_Drive:\n",
154
+ " os.chdir(\"/content/drive/MyDrive\")\n",
155
+ " save_model(weights, model_name + save_type)\n",
156
+ " os.chdir(\"/content\")\n",
157
+ " del weights\n",
158
+ "\n",
159
+ "!reset"
160
+ ],
161
+ "metadata": {
162
+ "id": "9OmSG98HxJg2",
163
+ "cellView": "form"
164
+ },
165
+ "execution_count": null,
166
+ "outputs": []
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "source": [
171
+ "SD2.x系モデル等を変換する場合は、付属の設定ファイル (モデルと同名の.yamlファイル) も同時にダウンロード/変換しましょう\n",
172
+ "\n",
173
+ "指定方法はモデルと同じです"
174
+ ],
175
+ "metadata": {
176
+ "id": "SWTFKmGFLec6"
177
+ }
178
+ },
179
+ {
180
+ "cell_type": "markdown",
181
+ "source": [
182
+ "メモリ不足でクラッシュする場合は、より小さいモデルを利用するか、有料のハイメモリランタイムを使用すること\n",
183
+ "\n",
184
+ "標準では10GBまでのモデルを変換できます"
185
+ ],
186
+ "metadata": {
187
+ "id": "0SUK6Alv2ItS"
188
+ }
189
+ },
190
+ {
191
+ "cell_type": "markdown",
192
+ "source": [
193
+ "モデルのリンク集: https://huggingface.co/models?other=stable-diffusion 等から好きなモデルを選ぼう"
194
+ ],
195
+ "metadata": {
196
+ "id": "yaLq5Nqe6an6"
197
+ }
198
+ }
199
+ ]
200
+ }