|
"""Helper script to convert diffusion checkpoints to format used by image generator.""" |
|
|
|
import os |
|
|
|
from absl import app |
|
from absl import flags |
|
import requests |
|
import torch as th |
|
|
|
|
|
# Command-line flags. --ckpt_path is mandatory; --output_path falls back to
# the "bins" folder when it is not given.
_CKPT_PATH = flags.DEFINE_string(
    "ckpt_path",
    default=None,
    help="Path to checkpoint file",
    required=True,
)
_OUTPUT_PATH = flags.DEFINE_string(
    "output_path",
    default="bins",
    help="Output folder path",
    required=False,
)
|
|
|
VOCAB_URL = "https://openaipublic.blob.core.windows.net/clip/bpe_simple_vocab_16e6.txt" |
|
|
|
|
|
def run(ckpt_path, output_path): |
|
"""Converts the checkpoint and saves the result. |
|
|
|
Args: |
|
ckpt_path: Source checkpoint path |
|
output_path: Result folder directory |
|
""" |
|
os.makedirs(output_path, exist_ok=True) |
|
ckpt = th.load(ckpt_path, map_location="cpu") |
|
|
|
vocab_dest = os.path.join(output_path, os.path.basename(VOCAB_URL)) |
|
if not os.path.exists(vocab_dest): |
|
with requests.get(VOCAB_URL, stream=True) as response: |
|
with open(vocab_dest, "wb") as vocab_file: |
|
for c in response.iter_content(chunk_size=8192): |
|
vocab_file.write(c) |
|
|
|
for k, v in ckpt["state_dict"].items(): |
|
if "first_stage_model.encoder" in k: |
|
continue |
|
if not hasattr(v, "numpy"): |
|
continue |
|
output_bin_file = os.path.join(output_path, f"{k}.bin") |
|
v.numpy().astype("float16").tofile(output_bin_file) |
|
|
|
|
|
def main(_) -> None:
    """Entry point for `app.run`: converts using the parsed flag values.

    The single positional argument supplied by `app.run` is ignored.
    """
    run(_CKPT_PATH.value, _OUTPUT_PATH.value)
|
|
|
if __name__ == "__main__": |
|
app.run(main) |
|
|