ntt123's picture
add app
587b6c9
raw
history blame
1.86 kB
"""
Utility functions
"""
import pickle
from pathlib import Path
import pax
import toml
import yaml
from tacotron import Tacotron
def load_tacotron_config(config_file=Path("tacotron.toml")):
    """
    Read a TOML configuration file and return its ``tacotron`` table.

    Args:
        config_file: anything ``toml.load`` accepts; defaults to
            ``tacotron.toml`` resolved against the current working directory.

    Returns:
        The mapping stored under the ``tacotron`` key of the parsed file.

    Raises:
        KeyError: if the file has no ``tacotron`` section.
    """
    parsed_toml = toml.load(config_file)
    tacotron_section = parsed_toml["tacotron"]
    return tacotron_section
def load_tacotron_ckpt(net: pax.Module, optim: pax.Module, path):
    """
    Restore a Tacotron training checkpoint from a pickle file on disk.

    Args:
        net: module whose ``load_state_dict`` receives
            ``checkpoint["model_state_dict"]``; pass ``None`` to skip it.
        optim: module whose ``load_state_dict`` receives
            ``checkpoint["optim_state_dict"]``; pass ``None`` to skip it.
        path: location of the pickled checkpoint.

    Returns:
        ``(step, net, optim)``, where ``net``/``optim`` are the values returned
        by their ``load_state_dict`` calls, or ``None`` if ``None`` was passed.

    Raises:
        KeyError: if a required entry is missing from the checkpoint.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted checkpoints.
    with open(path, "rb") as ckpt_stream:
        checkpoint = pickle.load(ckpt_stream)

    # A state dict is only looked up when there is an object to restore into.
    restored_net = (
        None if net is None else net.load_state_dict(checkpoint["model_state_dict"])
    )
    restored_optim = (
        None
        if optim is None
        else optim.load_state_dict(checkpoint["optim_state_dict"])
    )
    return checkpoint["step"], restored_net, restored_optim
def create_tacotron_model(config):
    """
    Build a fresh Tacotron model (no checkpoint weights loaded) from ``config``.

    Each ``Tacotron`` keyword argument is read from the upper-case config key
    of the same name, e.g. ``mel_dim`` comes from ``config["MEL_DIM"]``.

    Raises:
        KeyError: if any required hyper-parameter is missing from ``config``.
    """
    # (constructor keyword, config key) pairs, read in this order.
    keyword_to_config_key = (
        ("mel_dim", "MEL_DIM"),
        ("attn_bias", "ATTN_BIAS"),
        ("rr", "RR"),
        ("max_rr", "MAX_RR"),
        ("mel_min", "MEL_MIN"),
        ("sigmoid_noise", "SIGMOID_NOISE"),
        ("pad_token", "PAD_TOKEN"),
        ("prenet_dim", "PRENET_DIM"),
        ("attn_hidden_dim", "ATTN_HIDDEN_DIM"),
        ("attn_rnn_dim", "ATTN_RNN_DIM"),
        ("rnn_dim", "RNN_DIM"),
        ("postnet_dim", "POSTNET_DIM"),
        ("text_dim", "TEXT_DIM"),
    )
    hyper_params = {keyword: config[key] for keyword, key in keyword_to_config_key}
    return Tacotron(**hyper_params)
def load_wavegru_config(config_file):
    """
    Parse a WaveGRU YAML configuration file.

    Args:
        config_file: path of the YAML file, read as UTF-8 text.

    Returns:
        The Python object produced by ``yaml.safe_load``.
    """
    with open(config_file, mode="r", encoding="utf-8") as yaml_stream:
        parsed_config = yaml.safe_load(yaml_stream)
    return parsed_config
def load_wavegru_ckpt(net, optim, ckpt_file):
    """
    Restore a WaveGRU training checkpoint from a pickle file.

    Args:
        net: object whose ``load_state_dict`` receives
            ``checkpoint["net_state_dict"]``; pass ``None`` to skip it.
        optim: object whose ``load_state_dict`` receives
            ``checkpoint["optim_state_dict"]``; pass ``None`` to skip it.
        ckpt_file: path of the pickled checkpoint.

    Returns:
        ``(step, net, optim)``, where ``net``/``optim`` are the values returned
        by their ``load_state_dict`` calls, or ``None`` if ``None`` was passed.

    Raises:
        KeyError: if a required entry is missing from the checkpoint.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted checkpoints.
    with open(ckpt_file, "rb") as ckpt_stream:
        checkpoint = pickle.load(ckpt_stream)

    # A state dict is only looked up when there is an object to restore into.
    restored_net = (
        None if net is None else net.load_state_dict(checkpoint["net_state_dict"])
    )
    restored_optim = (
        None
        if optim is None
        else optim.load_state_dict(checkpoint["optim_state_dict"])
    )
    return checkpoint["step"], restored_net, restored_optim