{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["以下の2つのコードを両方とも実行"],"metadata":{"id":"OnuCk_wNLM_D"}},{"cell_type":"code","source":["from google.colab import drive \n","drive.mount(\"/content/drive\")"],"metadata":{"id":"liEiK8Iioscq"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["!pip install torch safetensors\n","!pip install wget"],"metadata":{"id":"pXr7oNJzwwgU"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["以下のリンク等を任意のものに差し替えてから、以下のコードを実行"],"metadata":{"id":"7Ils-K70k15Y"}},{"cell_type":"code","source":["#@title モデルをダウンロード\n","#@markdown {Google Drive上のモデル名 or モデルのダウンロードリンク} をカンマ区切りで任意個指定\n","#@markdown - Drive上のモデル名の場合...My Driveに対する相対パスで指定\n","#@markdown - ダウンロードリンクの場合...Hugging Face等のダウンロードリンクを右クリック & リンクのアドレスをコピー & 下のリンクの代わりに貼り付け\n","import shutil\n","import urllib.parse\n","import urllib.request\n","import wget\n","\n","models = \"Please use your own model in place of this example, example.safetensors, https://huggingface.co./stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt\" #@param {type:\"string\"}\n","models = [m.strip() for m in models.split(\",\") if not models == \"\"]\n","for model in models:\n"," if 0 < len(urllib.parse.urlparse(model).scheme): # if model is url\n"," wget.download(model)\n"," # once the bug on python 3.8 is fixed, replace the above code with the following code\n"," ## model_data = urllib.request.urlopen(model).read()\n"," ## with open(os.path.basename(model), mode=\"wb\") as f:\n"," ## f.write(model_data)\n"," elif model.endswith((\".ckpt\", \".safetensors\", \".pt\", \".pth\")):\n"," from_ = \"/content/drive/MyDrive/\" + model\n"," to_ = \"/content/\" + model\n"," shutil.copy(from_, to_)\n"," else:\n"," print(f\"\\\"{model}\\\"はURLではなく、正しい形式のファイルでもありません\")"],"metadata":{"cellView":"form","id":"4vd3A09AxJE0"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["この例のように、SD2.1等の比較的新しいモデルを利用する場合は、以下のコードを実行(さもなくばエラーが出力される)"],"metadata":{"id":"m1mHzOMjcDhz"}},{"cell_type":"code","source":["!pip install pytorch-lightning"],"metadata":{"id":"TkrmByc0aYVN"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["以下の2つのコードから好きな方を選んで実行\n","\n","メモリ不足でクラッシュする場合は、より小さいモデルを利用するか、有料のハイメモリランタイムを使用すること"],"metadata":{"id":"0SUK6Alv2ItS"}},{"cell_type":"code","source":["#@title 自分でモデル名を指定する場合\n","import os\n","import torch\n","import safetensors.torch\n","\n","model = \"v2-1_768-ema-pruned.ckpt\" #@param {type:\"string\"}\n","model_name, model_ext = os.path.splitext(model)\n","as_fp16 = True #@param {type:\"boolean\"}\n","save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n","\n","with torch.no_grad():\n"," if model_ext == \".safetensors\":\n"," weights = safetensors.torch.load_file(model_name + model_ext, device=\"cpu\")\n"," elif model_ext == \".ckpt\":\n"," weights = torch.load(model_name + model_ext, map_location=torch.device('cpu'))[\"state_dict\"]\n"," else:\n"," raise Exception(\"対応形式は.ckptと.safetensorsです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n"," if as_fp16:\n"," model_name = model_name + \"-fp16\"\n"," for key in weights.keys():\n"," weights[key] = weights[key].half()\n"," if save_directly_to_Google_Drive:\n"," os.chdir(\"/content/drive/MyDrive\")\n"," safetensors.torch.save_file(weights, model_name + \".safetensors\")\n"," os.chdir(\"/content\")\n"," else:\n"," safetensors.torch.save_file(weights, model_name + \".safetensors\")\n"," del weights\n","\n","!reset"],"metadata":{"id":"9OmSG98HxJg2","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title 自動で全モデルを変換する場合\n","import os\n","import glob\n","import torch\n","import safetensors.torch\n","\n","as_fp16 = True #@param {type:\"boolean\"}\n","save_directly_to_Google_Drive = True #@param {type:\"boolean\"}\n","\n","with torch.no_grad():\n"," model_paths = glob.glob(r\"/content/*.ckpt\") + glob.glob(r\"/content/*.safetensors\") + glob.glob(r\"/content/*.pt\") + glob.glob(r\"/content/*.pth\")\n"," for model_path in model_paths:\n"," model_name, model_ext = os.path.splitext(os.path.basename(model_path))\n"," if model_ext == \".safetensors\":\n"," weights = safetensors.torch.load_file(model_name + model_ext, device=\"cpu\")\n"," elif model_ext == \".ckpt\":\n"," weights = torch.load(model_name + model_ext, map_location=torch.device('cpu'))[\"state_dict\"]\n"," else:\n"," print(\"対応形式は.ckpt\tと.safetensorsです\\n\" + f\"\\\"{model}\\\"は対応形式ではありません\")\n"," break\n"," if as_fp16:\n"," model_name = model_name + \"-fp16\"\n"," for key in weights.keys():\n"," weights[key] = weights[key].half()\n"," if save_directly_to_Google_Drive:\n"," os.chdir(\"/content/drive/MyDrive\")\n"," safetensors.torch.save_file(weights, model_name + \".safetensors\")\n"," os.chdir(\"/content\")\n"," else:\n"," safetensors.torch.save_file(weights, model_name + \".safetensors\")\n"," del weights\n","\n","!reset"],"metadata":{"id":"5TUvrW5VzLst","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["モデルのリンク集: https://huggingface.co./models?other=stable-diffusion 等から好きなモデルを選ぼう"],"metadata":{"id":"yaLq5Nqe6an6"}}]}