import datetime
import hashlib
import os
from collections import OrderedDict

import torch


def replace_keys_in_dict(d, old_key_part, new_key_part):
    """Return a copy of ``d`` with ``old_key_part`` replaced in every key.

    The replacement is applied recursively to values that are themselves
    dictionaries. The container type is preserved: an ``OrderedDict`` input
    yields an ``OrderedDict`` output, anything else yields a plain ``dict``.
    The input mapping is not modified.

    Args:
        d: Mapping whose keys should be rewritten.
        old_key_part: Substring to look for in each string key.
        new_key_part: Replacement for ``old_key_part``.

    Returns:
        A new mapping with rewritten keys and (recursively rewritten) values.
    """
    updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
    for key, value in d.items():
        # Only string keys can contain the substring; other key types
        # (ints, tuples, ...) are carried over untouched instead of raising.
        if isinstance(key, str):
            new_key = key.replace(old_key_part, new_key_part)
        else:
            new_key = key
        if isinstance(value, dict):
            value = replace_keys_in_dict(value, old_key_part, new_key_part)
        updated_dict[new_key] = value
    return updated_dict


def _build_config(sr, version):
    """Return the 18-entry ``config`` list for a sample rate and version.

    ``sr`` may be an int or a numeric string. Only 32000, 40000 and 48000 are
    known; for 32000 and 48000 the layout depends on whether ``version`` is
    ``"v1"``. A fresh list is built on every call so callers never share
    nested lists.

    Raises:
        ValueError: If ``sr`` is not one of the supported sample rates.
    """
    try:
        sample_rate = int(sr)
    except (TypeError, ValueError):
        raise ValueError(f"Unsupported sample rate: {sr!r}") from None

    is_v1 = version == "v1"
    # NOTE(review): `head` is presumably the spectrogram channel count and
    # `rates` / `kernels` the upsample rates / kernel sizes of the
    # synthesizer -- confirm against the model constructor consuming this.
    if sample_rate == 40000:
        head, rates, kernels = 1025, [10, 10, 2, 2], [16, 16, 4, 4]
    elif sample_rate == 48000:
        head = 1025
        if is_v1:
            rates, kernels = [10, 6, 2, 2, 2], [16, 16, 4, 4, 4]
        else:
            rates, kernels = [12, 10, 2, 2], [24, 20, 4, 4]
    elif sample_rate == 32000:
        head = 513
        if is_v1:
            rates, kernels = [10, 4, 2, 2, 2], [16, 16, 4, 4, 4]
        else:
            rates, kernels = [10, 8, 2, 2], [20, 16, 4, 4]
    else:
        raise ValueError(f"Unsupported sample rate: {sr!r}")

    return [
        head,
        32,
        192,
        192,
        768,
        2,
        6,
        3,
        0,
        "1",
        [3, 7, 11],
        [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        rates,
        512,
        kernels,
        109,
        256,
        sample_rate,
    ]


def extract_small_model(
    path: str,
    name: str,
    sr: int,
    pitch_guidance: bool,
    version: str,
    epoch: int,
    step: int,
):
    """Extract a compact, half-precision model file from a checkpoint.

    Loads the checkpoint at ``path``, drops every weight whose key contains
    ``"enc_q"``, converts the remaining tensors to half precision, attaches
    metadata (config, epoch, step, sample rate, pitch-guidance flag, version,
    creation date and a SHA-256 model hash) and writes the result to
    ``f"{name}.pth"``.

    Weight keys using the parametrized weight-norm naming
    (``.parametrizations.weight.original0`` / ``original1``) are renamed to
    the legacy ``.weight_g`` / ``.weight_v`` names before saving.

    Args:
        path: Checkpoint file to read. Either a flat state dict or a dict
            holding the state dict under the ``"model"`` key.
        name: Output file name without the ``.pth`` extension.
        sr: Sample rate, as an int or numeric string (32000, 40000, 48000).
        pitch_guidance: Stored as ``int(pitch_guidance)`` under ``"f0"``.
        version: Model version; ``"v1"`` selects the v1 config layout.
        epoch: Epoch number recorded in the output.
        step: Step number recorded in the output.

    Errors are not propagated: any exception is caught and reported on
    stdout, matching the original best-effort contract.
    """
    try:
        ckpt = torch.load(path, map_location="cpu")
        pth_file = f"{name}.pth"

        # Unwrap training checkpoints *before* touching the tensors; the
        # wrapper also holds non-tensor entries that have no ``.half()``.
        if "model" in ckpt:
            ckpt = ckpt["model"]

        opt = OrderedDict()
        opt["weight"] = {
            key: value.half() for key, value in ckpt.items() if "enc_q" not in key
        }
        # Accepts both 40000 and "40000"; unsupported rates raise and are
        # reported by the handler below instead of silently omitting config.
        opt["config"] = _build_config(sr, version)

        # One timestamp so the stored date and the hashed date agree.
        created = datetime.datetime.now().isoformat()
        opt["epoch"] = epoch
        opt["step"] = step
        opt["sr"] = sr
        opt["f0"] = int(pitch_guidance)
        opt["version"] = version
        opt["creation_date"] = created

        hash_input = f"{str(ckpt)} {epoch} {step} {created}"
        opt["model_hash"] = hashlib.sha256(hash_input.encode()).hexdigest()

        # Rename the keys in memory and save once, rather than round-tripping
        # through an intermediate file that was never written.
        converted = replace_keys_in_dict(
            replace_keys_in_dict(
                opt, ".parametrizations.weight.original1", ".weight_v"
            ),
            ".parametrizations.weight.original0",
            ".weight_g",
        )
        torch.save(converted, pth_file)
    except Exception as error:
        print(f"An error occurred extracting the model: {error}")