import argparse
import importlib
import json
import os
from distutils.util import strtobool as dist_strtobool  # NB: distutils was removed in Python 3.12

import torch
import yaml

IGNORE_ID = -1  # padding/ignore label id (ESPnet convention)


def assign_args_from_yaml(args, yaml_path, prefix_key=None):
    """Overwrite matching attributes of ``args`` with values loaded from a YAML file."""
    with open(yaml_path) as f:
        ydict = yaml.load(f, Loader=yaml.FullLoader)
    if prefix_key is not None:
        ydict = ydict[prefix_key]
    for k, v in ydict.items():
        # YAML keys use dashes; the corresponding argparse attributes use underscores.
        k_args = k.replace("-", "_")
        if hasattr(args, k_args):
            setattr(args, k_args, v)
    return args
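# e.g., a YAML entry "batch-size: 32" (hypothetical key) becomes args.batch_size = 32,
# but only if args already defines that attribute.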


def get_model_conf(model_path):
    """Read the model.json stored next to a model checkpoint and return its args."""
    model_conf = os.path.join(os.path.dirname(model_path), "model.json")
    with open(model_conf, "rb") as f:
        print("reading a config file from " + model_conf)
        confs = json.load(f)
    # The config is stored as (idim, odim, args) for ASR/TTS/MT models;
    # only the args dict is needed here.
    idim, odim, args = confs
    return argparse.Namespace(**args)


def strtobool(x):
    """Cast a truthy/falsy string such as "yes"/"no" to bool."""
    return bool(dist_strtobool(x))
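# e.g. (hypothetical flag): parser.add_argument("--use-cuda", type=strtobool, default=True)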


def dynamic_import(import_path, alias=dict()):
    """Dynamically import a module-level object such as a class.

    :param str import_path: syntax 'module_name:class_name'
        e.g., 'espnet.transform.add_deltas:AddDeltas'
    :param dict alias: shortcut names for registered classes
    :return: imported object
    """
    if import_path not in alias and ":" not in import_path:
        raise ValueError(
            'import_path should be one of {} or include ":", '
            'e.g. "espnet.transform.add_deltas:AddDeltas": '
            "{}".format(set(alias), import_path)
        )
    if ":" not in import_path:
        # Resolve a registered shortcut to its full "module:object" path.
        import_path = alias[import_path]
    module_name, objname = import_path.split(":")
    m = importlib.import_module(module_name)
    return getattr(m, objname)
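# e.g., dynamic_import("espnet.transform.add_deltas:AddDeltas") returns the
# AddDeltas class; a bare name like "add_deltas" resolves only via ``alias``.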


def set_deterministic_pytorch(args):
    """Seed torch's RNG; cudnn autotuning is disabled but cudnn determinism is not enforced."""
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = False


def pad_list(xs, pad_value):
    """Right-pad a list of variable-length tensors into one (B, Tmax, ...) tensor."""
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    # Allocate the padded batch with the same dtype/device as the inputs.
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]
    return pad
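# e.g., pad_list([torch.ones(2), torch.ones(3)], 0.0) ->
#   tensor([[1., 1., 0.],
#           [1., 1., 1.]])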


def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a boolean (B, Tmax) mask that is True at padded positions."""
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    # A position is padding when its index is >= the sequence's true length.
    mask = seq_range_expand >= seq_length_expand
    return mask
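# e.g., make_pad_mask(torch.tensor([2, 3])) ->
#   tensor([[False, False,  True],
#           [False, False, False]])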


def subsequent_chunk_mask(
    size: int,
    ck_size: int,
    num_l_cks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a (size, size) chunk-wise attention mask.

    Position i may attend to every frame up to the end of its own chunk and
    to at most ``num_l_cks`` chunks on its left (all of them when negative).
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_l_cks < 0:
            start = 0
        else:
            start = max((i // ck_size - num_l_cks) * ck_size, 0)
        ending = min((i // ck_size + 1) * ck_size, size)
        ret[i, start:ending] = True
    return ret
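# e.g., subsequent_chunk_mask(4, 2) ->
#   tensor([[ True,  True, False, False],
#           [ True,  True, False, False],
#           [ True,  True,  True,  True],
#           [ True,  True,  True,  True]])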


def add_optional_chunk_mask(
    xs: torch.Tensor,
    masks: torch.Tensor,
    use_dynamic_chunk: bool,
    use_dynamic_left_chunk: bool,
    decoding_chunk_size: int,
    static_chunk_size: int,
    num_decoding_left_chunks: int,
):
    """Combine the padding mask with an optional chunk mask for streaming attention.

    :param xs: padded input, (B, L, D)
    :param masks: padding mask, (B, 1, L)
    :return: combined mask, (B, L, L), or ``masks`` unchanged if chunking is off
    """
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            # Full-context mode: one chunk spans the whole utterance.
            chunk_size = max_len
            num_l_cks = -1
        elif decoding_chunk_size > 0:
            # Fixed chunk size requested for decoding.
            chunk_size = decoding_chunk_size
            num_l_cks = num_decoding_left_chunks
        else:
            # Training: sample a chunk size from [1, max_len). Draws above
            # max_len // 2 fall back to full context; otherwise the draw is
            # squashed into a small streaming chunk in [1, 25].
            chunk_size = torch.randint(1, max_len, (1,)).item()
            num_l_cks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_l_cks = torch.randint(0, max_left_chunks, (1,)).item()
        ck_masks = subsequent_chunk_mask(
            xs.size(1), chunk_size, num_l_cks, xs.device
        )  # (L, L)
        ck_masks = ck_masks.unsqueeze(0)  # (1, L, L)
        ck_masks = masks & ck_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_l_cks = num_decoding_left_chunks
        ck_masks = subsequent_chunk_mask(
            xs.size(1), static_chunk_size, num_l_cks, xs.device
        )  # (L, L)
        ck_masks = ck_masks.unsqueeze(0)  # (1, L, L)
        ck_masks = masks & ck_masks  # (B, L, L)
    else:
        ck_masks = masks
    return ck_masks
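# Typical streaming-encoder usage (sketch; variable names assumed, not from this file):
#   masks = ~make_pad_mask(lengths).unsqueeze(1)              # (B, 1, L), True = valid
#   att_mask = add_optional_chunk_mask(xs, masks, True, False, 0, 0, -1)
# yields a (B, L, L) boolean mask combining padding with a sampled chunk pattern.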