File size: 1,494 Bytes
aaf8596
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""Helper script to convert diffusion checkpoints to format used by image generator."""

import os

from absl import app
from absl import flags
import requests
import torch as th


# Command-line flags: `ckpt_path` must be supplied, `output_path` falls back
# to "bins".
_CKPT_PATH = flags.DEFINE_string(
    "ckpt_path", None, "Path to checkpoint file", required=True)
_OUTPUT_PATH = flags.DEFINE_string(
    "output_path", "bins", "Output folder path")

# Vocabulary file that `run` downloads into the output folder.
VOCAB_URL = (
    "https://openaipublic.blob.core.windows.net/clip/"
    "bpe_simple_vocab_16e6.txt")


def run(ckpt_path, output_path):
  """Converts the checkpoint and saves the result.

  Loads the checkpoint on CPU and downloads the vocabulary file at
  `VOCAB_URL` into `output_path` unless that file already exists. It then
  writes every tensor in the checkpoint's "state_dict" to
  `<output_path>/<key>.bin` as raw float16 values. Keys containing
  "first_stage_model.encoder" and non-tensor entries are skipped.

  Args:
    ckpt_path: Source checkpoint path
    output_path: Result folder directory

  Raises:
    requests.HTTPError: If the vocabulary download returns an error status.
    KeyError: If the checkpoint has no "state_dict" entry.
  """
  os.makedirs(output_path, exist_ok=True)
  # NOTE(review): th.load unpickles arbitrary Python objects; only run this
  # script on checkpoints from a trusted source.
  ckpt = th.load(ckpt_path, map_location="cpu")

  vocab_dest = os.path.join(output_path, os.path.basename(VOCAB_URL))
  if not os.path.exists(vocab_dest):
    _download_file(VOCAB_URL, vocab_dest)

  _export_tensors(ckpt["state_dict"], output_path)


def _download_file(url, dest, timeout=60):
  """Streams `url` into `dest`; `dest` only appears after a complete download.

  Args:
    url: URL to fetch.
    dest: Final file path.
    timeout: Seconds to wait when connecting and between received chunks.

  Raises:
    requests.HTTPError: If the server responds with an error status.
  """
  tmp_dest = dest + ".part"
  try:
    with requests.get(url, stream=True, timeout=timeout) as response:
      # Without this check an HTTP error page would be saved as `dest`. Since
      # callers skip existing files, it would never be re-downloaded.
      response.raise_for_status()
      with open(tmp_dest, "wb") as out_file:
        for chunk in response.iter_content(chunk_size=8192):
          out_file.write(chunk)
    # Rename only after success, so an interrupted download never leaves a
    # truncated `dest` behind.
    os.replace(tmp_dest, dest)
  finally:
    if os.path.exists(tmp_dest):
      os.remove(tmp_dest)


def _export_tensors(state_dict, output_path):
  """Writes each tensor in `state_dict` to `<output_path>/<key>.bin` as float16."""
  for key, value in state_dict.items():
    if "first_stage_model.encoder" in key:
      continue
    if not isinstance(value, th.Tensor):
      continue
    # detach(): .numpy() raises on tensors that require grad.
    tensor = value.detach().cpu()
    # NumPy has no bfloat16; the upcast to float32 is exact, so the final
    # float16 values match a direct conversion.
    if tensor.dtype == th.bfloat16:
      tensor = tensor.float()
    output_bin_file = os.path.join(output_path, f"{key}.bin")
    tensor.numpy().astype("float16").tofile(output_bin_file)


def main(_) -> None:
  """absl entry point: passes the parsed flag values straight to `run`."""
  run(_CKPT_PATH.value, _OUTPUT_PATH.value)


if __name__ == "__main__":
  # app.run parses the command-line flags, then calls `main`.
  app.run(main)