Upload as_safetensors+fp16_en.ipynb
Browse files- as_safetensors+fp16_en.ipynb +219 -0
as_safetensors+fp16_en.ipynb
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"nbformat": 4,
|
3 |
+
"nbformat_minor": 0,
|
4 |
+
"metadata": {
|
5 |
+
"colab": {
|
6 |
+
"provenance": []
|
7 |
+
},
|
8 |
+
"kernelspec": {
|
9 |
+
"name": "python3",
|
10 |
+
"display_name": "Python 3"
|
11 |
+
},
|
12 |
+
"language_info": {
|
13 |
+
"name": "python"
|
14 |
+
}
|
15 |
+
},
|
16 |
+
"cells": [
|
17 |
+
{
|
18 |
+
"cell_type": "markdown",
|
19 |
+
"source": [
|
20 |
+
"Run both of the following two codes"
|
21 |
+
],
|
22 |
+
"metadata": {
|
23 |
+
"id": "OnuCk_wNLM_D"
|
24 |
+
}
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "code",
|
28 |
+
"source": [
|
29 |
+
"from google.colab import drive \n",
|
30 |
+
"drive.mount(\"/content/drive\")"
|
31 |
+
],
|
32 |
+
"metadata": {
|
33 |
+
"id": "liEiK8Iioscq"
|
34 |
+
},
|
35 |
+
"execution_count": null,
|
36 |
+
"outputs": []
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"cell_type": "code",
|
40 |
+
"source": [
|
41 |
+
"!pip install torch safetensors\n",
|
42 |
+
"!pip install wget"
|
43 |
+
],
|
44 |
+
"metadata": {
|
45 |
+
"id": "pXr7oNJzwwgU"
|
46 |
+
},
|
47 |
+
"execution_count": null,
|
48 |
+
"outputs": []
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"cell_type": "markdown",
|
52 |
+
"source": [
|
53 |
+
"Replace the following links, etc. with the desired ones and then run the following code"
|
54 |
+
],
|
55 |
+
"metadata": {
|
56 |
+
"id": "7Ils-K70k15Y"
|
57 |
+
}
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"source": [
|
62 |
+
"#@title <font size=\"-0\">Download Models</font>\n",
|
63 |
+
"#@markdown Please specify the model name or download link for Google Drive, separated by commas\n",
|
64 |
+
"#@markdown - If it is the model name on Google Drive, specify it as a relative path to My Drive\n",
|
65 |
+
"#@markdown - If it is a download link, copy the link address by right-clicking and paste it in place of the link below\n",
|
66 |
+
"\n",
|
67 |
+
"import shutil\n",
|
68 |
+
"import urllib.parse\n",
|
69 |
+
"import urllib.request\n",
|
70 |
+
"import wget\n",
|
71 |
+
"\n",
|
72 |
+
"models = \"Please use your own model in place of this example, example.safetensors, https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt\" #@param {type:\"string\"}\n",
|
73 |
+
"models = [m.strip() for m in models.split(\",\") if not models == \"\"]\n",
|
74 |
+
"for model in models:\n",
|
75 |
+
" if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n",
|
76 |
+
" wget.download(model)\n",
|
77 |
+
" # once the bug on python 3.8 is fixed, replace the above code with the following code\n",
|
78 |
+
" ## model_data = urllib.request.urlopen(model).read()\n",
|
79 |
+
" ## with open(os.path.basename(model), mode=\"wb\") as f:\n",
|
80 |
+
" ## f.write(model_data)\n",
|
81 |
+
" elif model.endswith((\".ckpt\", \".safetensors\", \".pt\", \".pth\")):\n",
|
82 |
+
" from_ = \"/content/drive/MyDrive/\" + model\n",
|
83 |
+
" to_ = \"/content/\" + model\n",
|
84 |
+
" shutil.copy(from_, to_)\n",
|
85 |
+
" else:\n",
|
86 |
+
" print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"
|
87 |
+
],
|
88 |
+
"metadata": {
|
89 |
+
"cellView": "form",
|
90 |
+
"id": "4vd3A09AxJE0"
|
91 |
+
},
|
92 |
+
"execution_count": null,
|
93 |
+
"outputs": []
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"cell_type": "markdown",
|
97 |
+
"source": [
|
98 |
+
"if you use a relatively newer model such as SD2.1, run the following code"
|
99 |
+
],
|
100 |
+
"metadata": {
|
101 |
+
"id": "m1mHzOMjcDhz"
|
102 |
+
}
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"source": [
|
107 |
+
"!pip install pytorch-lightning"
|
108 |
+
],
|
109 |
+
"metadata": {
|
110 |
+
"id": "TkrmByc0aYVN"
|
111 |
+
},
|
112 |
+
"execution_count": null,
|
113 |
+
"outputs": []
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "markdown",
|
117 |
+
"source": [
|
118 |
+
"Run either of the following two codes. If you run out of memory and crash, use a smaller model or a paid high-memory runtime"
|
119 |
+
],
|
120 |
+
"metadata": {
|
121 |
+
"id": "0SUK6Alv2ItS"
|
122 |
+
}
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"cell_type": "code",
|
126 |
+
"source": [
|
127 |
+
"#@title <font size=\"-0\">If you specify the name of the model you want to convert and convert it manually</font>\n",
|
128 |
+
"import os\n",
|
129 |
+
"import torch\n",
|
130 |
+
"import safetensors.torch\n",
|
131 |
+
"\n",
|
132 |
+
"model = \"v2-1_768-ema-pruned.ckpt\" #@param {type:\"string\"}\n",
|
133 |
+
"model_name, model_ext = os.path.splitext(model)\n",
|
134 |
+
"as_fp16 = True #@param {type:\"boolean\"}\n",
|
135 |
+
"save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
|
136 |
+
"\n",
|
137 |
+
"with torch.no_grad():\n",
|
138 |
+
" if model_ext == \".safetensors\":\n",
|
139 |
+
" weights = safetensors.torch.load_file(model_name + model_ext, device=\"cpu\")\n",
|
140 |
+
" elif model_ext == \".ckpt\":\n",
|
141 |
+
" weights = torch.load(model_name + model_ext, map_location=torch.device('cpu'))[\"state_dict\"]\n",
|
142 |
+
" else:\n",
|
143 |
+
" raise Exception(\"対応形式は.ckptと.safetensorsです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
|
144 |
+
" if as_fp16:\n",
|
145 |
+
" model_name = model_name + \"-fp16\"\n",
|
146 |
+
" for key in weights.keys():\n",
|
147 |
+
" weights[key] = weights[key].half()\n",
|
148 |
+
" if save_directly_to_Google_Drive:\n",
|
149 |
+
" os.chdir(\"/content/drive/MyDrive\")\n",
|
150 |
+
" safetensors.torch.save_file(weights, model_name + \".safetensors\")\n",
|
151 |
+
" os.chdir(\"/content\")\n",
|
152 |
+
" else:\n",
|
153 |
+
" safetensors.torch.save_file(weights, model_name + \".safetensors\")\n",
|
154 |
+
" del weights\n",
|
155 |
+
"\n",
|
156 |
+
"!reset"
|
157 |
+
],
|
158 |
+
"metadata": {
|
159 |
+
"id": "9OmSG98HxJg2",
|
160 |
+
"cellView": "form"
|
161 |
+
},
|
162 |
+
"execution_count": null,
|
163 |
+
"outputs": []
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"source": [
|
168 |
+
"#@title <font size=\"-0\">If you automatically convert all pre-loaded models</font>\n",
|
169 |
+
"import os\n",
|
170 |
+
"import glob\n",
|
171 |
+
"import torch\n",
|
172 |
+
"import safetensors.torch\n",
|
173 |
+
"\n",
|
174 |
+
"as_fp16 = True #@param {type:\"boolean\"}\n",
|
175 |
+
"save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n",
|
176 |
+
"\n",
|
177 |
+
"with torch.no_grad():\n",
|
178 |
+
" model_paths = glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.pt\") + glob.glob(r\"/content/*.pth\")\n",
|
179 |
+
" for model_path in model_paths:\n",
|
180 |
+
" model_name, model_ext = os.path.splitext(os.path.basename(model_path))\n",
|
181 |
+
" if model_ext == \".safetensors\":\n",
|
182 |
+
" weights = safetensors.torch.load_file(model_name + model_ext, device=\"cpu\")\n",
|
183 |
+
" elif model_ext == \".ckpt\":\n",
|
184 |
+
" weights = torch.load(model_name + model_ext, map_location=torch.device('cpu'))[\"state_dict\"]\n",
|
185 |
+
" else:\n",
|
186 |
+
" print(\"対応形式は.ckpt\tと.safetensorsです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n",
|
187 |
+
" break\n",
|
188 |
+
" if as_fp16:\n",
|
189 |
+
" model_name = model_name + \"-fp16\"\n",
|
190 |
+
" for key in weights.keys():\n",
|
191 |
+
" weights[key] = weights[key].half()\n",
|
192 |
+
" if save_directly_to_Google_Drive:\n",
|
193 |
+
" os.chdir(\"/content/drive/MyDrive\")\n",
|
194 |
+
" safetensors.torch.save_file(weights, model_name + \".safetensors\")\n",
|
195 |
+
" os.chdir(\"/content\")\n",
|
196 |
+
" else:\n",
|
197 |
+
" safetensors.torch.save_file(weights, model_name + \".safetensors\")\n",
|
198 |
+
" del weights\n",
|
199 |
+
"\n",
|
200 |
+
"!reset"
|
201 |
+
],
|
202 |
+
"metadata": {
|
203 |
+
"id": "5TUvrW5VzLst",
|
204 |
+
"cellView": "form"
|
205 |
+
},
|
206 |
+
"execution_count": null,
|
207 |
+
"outputs": []
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "markdown",
|
211 |
+
"source": [
|
212 |
+
"Choose your favorite model from https://huggingface.co/models?other=stable-diffusion or other model link collections"
|
213 |
+
],
|
214 |
+
"metadata": {
|
215 |
+
"id": "yaLq5Nqe6an6"
|
216 |
+
}
|
217 |
+
}
|
218 |
+
]
|
219 |
+
}
|