import argparse
import importlib
import json
import os

import torch
import yaml
IGNORE_ID = -1

def assign_args_from_yaml(args, yaml_path, prefix_key=None):
    """Overwrite attributes of ``args`` with matching keys from a YAML file.

    :param argparse.Namespace args: namespace to update in place
    :param str yaml_path: path to the YAML config
    :param str prefix_key: if given, only use the sub-dict under this key
    :return: the updated namespace
    """
    with open(yaml_path) as f:
        ydict = yaml.load(f, Loader=yaml.FullLoader)
    if prefix_key is not None:
        ydict = ydict[prefix_key]
    for k, v in ydict.items():
        # YAML keys use dashes; argparse attributes use underscores.
        k_args = k.replace("-", "_")
        if hasattr(args, k_args):
            setattr(args, k_args, v)
    return args
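
# Example (sketch, hypothetical path and key): with "batch-size: 8" in
# conf/train.yaml, an existing args.batch_size attribute would be
# overwritten; keys with no matching attribute on args are silently ignored.
#   >>> args = argparse.Namespace(batch_size=1)
#   >>> # assign_args_from_yaml(args, "conf/train.yaml").batch_size -> 8
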
def get_model_conf(model_path):
    """Load the model.json stored alongside a trained model.

    :param str model_path: path to the model file; model.json is expected
        in the same directory
    :return: the training arguments as an argparse.Namespace
    """
    model_conf = os.path.join(os.path.dirname(model_path), "model.json")
    with open(model_conf, "rb") as f:
        print("reading a config file from " + model_conf)
        confs = json.load(f)
    # For ASR, TTS, and MT models the file stores [idim, odim, args].
    idim, odim, args = confs
    return argparse.Namespace(**args)

def strtobool(x):
    # distutils.util.strtobool was removed in Python 3.12; replicate its parsing.
    if x.lower() in ("y", "yes", "t", "true", "on", "1"):
        return True
    if x.lower() in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError("invalid truth value %r" % (x,))

def dynamic_import(import_path, alias=dict()):
    """Dynamically import a module-level class or function.

    :param str import_path: syntax 'module_name:class_name'
        e.g., 'espnet.transform.add_deltas:AddDeltas'
    :param dict alias: shortcut for registered classes
    :return: imported class
    """
    if import_path not in alias and ":" not in import_path:
        raise ValueError(
            "import_path should be one of {} or "
            'include ":", e.g. "espnet.transform.add_deltas:AddDeltas" : '
            "{}".format(set(alias), import_path)
        )
    if ":" not in import_path:
        import_path = alias[import_path]
    module_name, objname = import_path.split(":")
    m = importlib.import_module(module_name)
    return getattr(m, objname)
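
# Example (sketch): resolving a "module:object" path to the object itself.
#   >>> dynamic_import("json:loads") is json.loads
#   True
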
def set_deterministic_pytorch(args):
    """Seed PyTorch and configure cuDNN for reproducible runs."""
    torch.manual_seed(args.seed)
    # Disable cuDNN autotuning and use deterministic kernels so repeated
    # runs with the same seed produce the same results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def pad_list(xs, pad_value):
    """Right-pad a list of tensors to the longest first dimension.

    :param list xs: list of tensors with shape (T_i, ...)
    :param float pad_value: value used for padding
    :return: padded tensor of shape (B, T_max, ...)
    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]
    return pad
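
# Example (sketch): right-pad two variable-length sequences with IGNORE_ID.
#   >>> pad_list([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], IGNORE_ID)
#   tensor([[ 1,  2,  3],
#           [ 4,  5, -1]])
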
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Build a mask that is True at padded positions.

    :param torch.Tensor lengths: batch of sequence lengths, shape (B,)
    :param int max_len: pad to this length; 0 means use lengths.max()
    :return: boolean mask of shape (B, max_len)
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
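
# Example (sketch): positions at or beyond each sequence length are True.
#   >>> make_pad_mask(torch.tensor([3, 1]))
#   tensor([[False, False, False],
#           [False,  True,  True]])
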
def subsequent_chunk_mask(
    size: int,
    ck_size: int,
    num_l_cks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a chunk-wise attention mask for streaming encoders.

    :param int size: sequence length
    :param int ck_size: chunk size
    :param int num_l_cks: number of left chunks to attend to; <0 means all
    :param torch.device device: device of the returned mask
    :return: boolean mask of shape (size, size)
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_l_cks < 0:
            start = 0
        else:
            start = max((i // ck_size - num_l_cks) * ck_size, 0)
        ending = min((i // ck_size + 1) * ck_size, size)
        ret[i, start:ending] = True
    return ret
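
# Example (sketch): chunk size 2 with unlimited left context (num_l_cks=-1);
# every frame sees all complete chunks up to and including its own chunk.
#   >>> subsequent_chunk_mask(4, 2)
#   tensor([[ True,  True, False, False],
#           [ True,  True, False, False],
#           [ True,  True,  True,  True],
#           [ True,  True,  True,  True]])
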
def add_optional_chunk_mask(
    xs: torch.Tensor,
    masks: torch.Tensor,
    use_dynamic_chunk: bool,
    use_dynamic_left_chunk: bool,
    decoding_chunk_size: int,
    static_chunk_size: int,
    num_decoding_left_chunks: int,
) -> torch.Tensor:
    """Apply an optional chunk mask on top of the non-pad mask.

    :param torch.Tensor xs: padded input of shape (B, L, D)
    :param torch.Tensor masks: non-pad mask of shape (B, 1, L)
    :param bool use_dynamic_chunk: sample a random chunk size while training
    :param bool use_dynamic_left_chunk: sample a random number of left chunks
    :param int decoding_chunk_size: <0 uses the full sequence as one chunk,
        >0 uses a fixed chunk size, 0 samples a chunk size per batch
    :param int static_chunk_size: fixed chunk size used when dynamic chunking
        is disabled; <=0 disables chunk masking entirely
    :param int num_decoding_left_chunks: left chunks each frame may attend
        to; <0 means all
    :return: combined mask of shape (B, L, L), or ``masks`` unchanged
    """
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_l_cks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_l_cks = num_decoding_left_chunks
        else:
            # Sample a chunk size per batch: roughly half the time use the
            # full sequence, otherwise a small chunk size in [1, 25].
            chunk_size = torch.randint(1, max_len, (1,)).item()
            num_l_cks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_l_cks = torch.randint(0, max_left_chunks, (1,)).item()
        ck_masks = subsequent_chunk_mask(
            xs.size(1), chunk_size, num_l_cks, xs.device
        )  # (L, L)
        ck_masks = ck_masks.unsqueeze(0)  # (1, L, L)
        ck_masks = masks & ck_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_l_cks = num_decoding_left_chunks
        ck_masks = subsequent_chunk_mask(
            xs.size(1), static_chunk_size, num_l_cks, xs.device
        )  # (L, L)
        ck_masks = ck_masks.unsqueeze(0)  # (1, L, L)
        ck_masks = masks & ck_masks  # (B, L, L)
    else:
        ck_masks = masks
    return ck_masks
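
# Usage sketch (assumed shapes): combine a non-pad mask with a static
# 2-frame chunk mask, as a streaming encoder would at inference time.
#   >>> xs = torch.randn(2, 8, 4)                                  # (B, L, D)
#   >>> masks = ~make_pad_mask(torch.tensor([8, 6])).unsqueeze(1)  # (B, 1, L)
#   >>> add_optional_chunk_mask(xs, masks, False, False, 0, 2, -1).shape
#   torch.Size([2, 8, 8])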