File size: 1,494 Bytes
aaf8596 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
"""Helper script to convert diffusion checkpoints to the format used by the image generator."""
import os
from absl import app
from absl import flags
import requests
import torch as th
# Command-line flags. --ckpt_path is mandatory; --output_path falls back to
# "bins" (a relative path, so resolved against the working directory).
_CKPT_PATH = flags.DEFINE_string(
    name="ckpt_path",
    default=None,
    help="Path to checkpoint file",
    required=True,
)
_OUTPUT_PATH = flags.DEFINE_string(
    name="output_path",
    default="bins",
    help="Output folder path",
    required=False,
)
VOCAB_URL = (
    "https://openaipublic.blob.core.windows.net/clip/"
    "bpe_simple_vocab_16e6.txt"
)


def _download_vocab(dest, timeout=60):
  """Downloads VOCAB_URL to ``dest`` without ever leaving a partial file.

  The response body is streamed into ``dest + ".part"``. That file is renamed
  onto ``dest`` only after the whole body has been received. A failed or
  interrupted download therefore never leaves a file at ``dest``, which later
  runs would otherwise mistake for a valid vocabulary.

  Args:
    dest: Destination file path.
    timeout: Seconds handed to ``requests`` as its connect/read timeout, so a
      stalled server cannot hang the script forever.

  Raises:
    requests.HTTPError: If the server answers with an error status.
    requests.RequestException: On connection failures or timeouts.
  """
  partial_path = dest + ".part"
  try:
    with requests.get(VOCAB_URL, stream=True, timeout=timeout) as response:
      # Without this check, an HTTP error page would be saved as the vocab.
      response.raise_for_status()
      with open(partial_path, "wb") as vocab_file:
        for chunk in response.iter_content(chunk_size=8192):
          vocab_file.write(chunk)
    os.replace(partial_path, dest)
  finally:
    # The partial file only still exists if something failed before the rename.
    if os.path.exists(partial_path):
      os.remove(partial_path)


def _export_tensors(state_dict, output_path):
  """Writes each exportable entry of ``state_dict`` as a raw float16 file.

  An entry is skipped if its key contains "first_stage_model.encoder" or if
  its value has no ``numpy`` method. Every other value is cast to float16
  (integer tensors included) and written headerless, in native byte order,
  to ``<output_path>/<key>.bin``.

  Args:
    state_dict: Mapping from parameter name to tensor.
    output_path: Existing directory that receives the ``.bin`` files.
  """
  for key, value in state_dict.items():
    if "first_stage_model.encoder" in key:
      continue
    if not hasattr(value, "numpy"):
      continue
    if isinstance(value, th.Tensor):
      # Tensor.numpy() refuses tensors that require grad or live off the CPU.
      value = value.detach().cpu()
      if value.dtype == th.bfloat16:
        # NumPy has no bfloat16 dtype. float32 represents every bfloat16
        # value exactly, so the float16 cast below still rounds only once.
        value = value.float()
    output_bin_file = os.path.join(output_path, f"{key}.bin")
    value.numpy().astype("float16").tofile(output_bin_file)


def run(ckpt_path, output_path):
  """Converts the checkpoint and saves the result.

  Creates ``output_path`` if needed. Downloads the CLIP BPE vocabulary file
  into it, unless that file already exists. Then writes every exportable
  tensor of ``ckpt["state_dict"]`` to ``<output_path>/<key>.bin`` as raw
  float16.

  Args:
    ckpt_path: Source checkpoint path
    output_path: Result folder directory

  Raises:
    KeyError: If the checkpoint has no "state_dict" entry.
  """
  os.makedirs(output_path, exist_ok=True)
  # NOTE: th.load is pickle-based. Depending on the torch version and its
  # weights_only default, it may execute code from the file, so only convert
  # checkpoints from trusted sources.
  ckpt = th.load(ckpt_path, map_location="cpu")
  vocab_dest = os.path.join(output_path, os.path.basename(VOCAB_URL))
  if not os.path.exists(vocab_dest):
    _download_vocab(vocab_dest)
  _export_tensors(ckpt["state_dict"], output_path)
def main(_) -> None:
  """absl entry point: forwards the parsed flag values to ``run``.

  Args:
    _: Argument supplied by ``app.run``; unused.
  """
  run(ckpt_path=_CKPT_PATH.value, output_path=_OUTPUT_PATH.value)


if __name__ == "__main__":
  # Hand control to absl, which invokes main().
  app.run(main)
|