"""E2E VC training / decoding functions.""" |
import copy |
import json |
import logging |
import math |
import os |
import time |
import chainer |
import kaldiio |
import numpy as np |
import torch |
from chainer import training |
from chainer.training import extensions |
from espnet.asr.asr_utils import get_model_conf |
from espnet.asr.asr_utils import snapshot_object |
from espnet.asr.asr_utils import torch_load |
from espnet.asr.asr_utils import torch_resume |
from espnet.asr.asr_utils import torch_snapshot |
from espnet.asr.pytorch_backend.asr_init import load_trained_modules |
from espnet.nets.pytorch_backend.nets_utils import pad_list |
from espnet.nets.tts_interface import TTSInterface |
from espnet.utils.dataset import ChainerDataLoader |
from espnet.utils.dataset import TransformDataset |
from espnet.utils.dynamic_import import dynamic_import |
from espnet.utils.io_utils import LoadInputsAndTargets |
from espnet.utils.training.batchfy import make_batchset |
from espnet.utils.training.evaluator import BaseEvaluator |
from espnet.utils.deterministic_utils import set_deterministic_pytorch |
from espnet.utils.training.train_utils import check_early_stop |
from espnet.utils.training.train_utils import set_early_stop |
from espnet.utils.training.iterators import ShufflingEnabler |
import matplotlib |
from espnet.utils.training.tensorboard_logger import TensorboardLogger |
from tensorboardX import SummaryWriter |
# Select the non-interactive "Agg" backend so that the figures produced by this
# module (loss/attention/probability plots) can be rendered and saved on
# machines without a display.
matplotlib.use("Agg")
class CustomEvaluator(BaseEvaluator):
    """Custom evaluator.

    Runs a PyTorch model over the validation iterator and averages every
    value the model reports through ``chainer.reporter`` during its forward
    pass.
    """

    def __init__(self, model, iterator, target, device):
        """Initialize module.

        Args:
            model (torch.nn.Module): Pytorch model instance.
            iterator (chainer.dataset.Iterator): Iterator for validation.
            target (chainer.Chain): Dummy chain instance.
            device (torch.device): The device to be used in evaluation.

        """
        super(CustomEvaluator, self).__init__(iterator, target)
        self.model = model
        self.device = device

    def evaluate(self):
        """Evaluate over validation iterator.

        Returns:
            dict: Mean of every observation reported over the whole pass.

        """
        valid_iter = self._iterators["main"]
        if self.eval_hook:
            self.eval_hook(self)

        # Resettable iterators are rewound and reused directly; anything else
        # is shallow-copied so the original iterator's position is untouched.
        if hasattr(valid_iter, "reset"):
            valid_iter.reset()
            batches = valid_iter
        else:
            batches = copy.copy(valid_iter)

        summary = chainer.reporter.DictSummary()
        self.model.eval()
        with torch.no_grad():
            for batch in batches:
                # Normalise both batch layouts into positional / keyword args,
                # moving every tensor onto the evaluation device.
                if isinstance(batch, tuple):
                    call_args = tuple(arr.to(self.device) for arr in batch)
                    call_kwargs = {}
                else:
                    for name in batch.keys():
                        batch[name] = batch[name].to(self.device)
                    call_args = ()
                    call_kwargs = batch

                observation = {}
                with chainer.reporter.report_scope(observation):
                    self.model(*call_args, **call_kwargs)
                summary.add(observation)
        self.model.train()
        return summary.compute_mean()
class CustomUpdater(training.StandardUpdater):
    """Custom updater.

    Runs forward/backward on one minibatch per call and steps the optimizer
    once every ``accum_grad`` calls, after clipping the gradient norm. The
    chainer iteration counter only advances on calls that stepped (or tried
    to step) the optimizer.
    """

    def __init__(self, model, grad_clip, iterator, optimizer, device, accum_grad=1):
        """Initialize module.

        Args:
            model (torch.nn.Module): Pytorch model instance.
            grad_clip (float): The gradient clipping value.
            iterator (chainer.dataset.Iterator): Iterator for training.
            optimizer (torch.optim.Optimizer): Pytorch optimizer instance.
            device (torch.device): The device to be used in training.
            accum_grad (int): Number of minibatches whose gradients are
                accumulated before each optimizer step.

        """
        super(CustomUpdater, self).__init__(iterator, optimizer)
        self.model = model
        self.grad_clip = grad_clip
        self.device = device
        self.clip_grad_norm = torch.nn.utils.clip_grad_norm_
        self.accum_grad = accum_grad
        self.forward_count = 0

    def update_core(self):
        """Update model one step."""
        train_iter = self.get_iterator("main")
        optimizer = self.get_optimizer("main")

        batch = train_iter.next()
        if isinstance(batch, tuple):
            x = tuple(arr.to(self.device) for arr in batch)
        else:
            x = batch
            for key in x.keys():
                x[key] = x[key].to(self.device)

        # Scale the loss so the accumulated gradient is an average over
        # ``accum_grad`` minibatches.
        if isinstance(x, tuple):
            loss = self.model(*x).mean() / self.accum_grad
        else:
            loss = self.model(**x).mean() / self.accum_grad
        loss.backward()

        # Only step once enough minibatches have been accumulated.
        self.forward_count += 1
        if self.forward_count != self.accum_grad:
            return
        self.forward_count = 0

        grad_norm = self.clip_grad_norm(self.model.parameters(), self.grad_clip)
        logging.debug("grad norm={}".format(grad_norm))
        # Skip the step for NaN *and* infinite norms: an infinite norm cannot
        # be clipped back to a finite value, so stepping would push inf/NaN
        # into the parameters (and optimizer state).
        if not math.isfinite(grad_norm):
            logging.warning(
                "grad norm is {}. Do not update model.".format(float(grad_norm))
            )
        else:
            optimizer.step()
        optimizer.zero_grad()

    def update(self):
        """Run update function."""
        self.update_core()
        # forward_count is reset to 0 only after an optimizer step attempt.
        if self.forward_count == 0:
            self.iteration += 1
class CustomConverter(object):
    """Custom converter.

    Turns one loaded VC minibatch into zero-padded float tensors, length
    tensors, and a ``labels`` tensor that is 1.0 from each sequence's last
    valid output frame onward and 0.0 before it.
    """

    def __init__(self):
        """Initialize module."""
        pass

    def __call__(self, batch, device=torch.device("cpu")):
        """Convert a given batch.

        Args:
            batch (list): Single-element list holding the tuple
                ``(xs, ys, spembs, extras)``. ``xs`` and ``ys`` are lists of
                ndarrays whose first axis is the sequence length; ``spembs``
                and ``extras`` are lists of ndarrays or None.
            device (torch.device): The device to be send.

        Returns:
            dict: Dict of converted tensors with keys ``xs``, ``ilens``,
                ``ys``, ``labels``, ``olens`` and, when provided, ``spembs``
                and ``extras``. Sequence tensors are float32 and zero-padded
                along the first (length) axis of each item.

        Examples:
            >>> batch = [([np.ones((5, 2)), np.ones((3, 2))],
            ...           [np.ones((4, 3)), np.ones((2, 3))],
            ...           None, None)]
            >>> new_batch = CustomConverter()(batch, torch.device("cpu"))
            >>> new_batch["xs"].shape, new_batch["ilens"].tolist()
            (torch.Size([2, 5, 2]), [5, 3])
            >>> new_batch["olens"].tolist()
            [4, 2]
            >>> new_batch["labels"]
            tensor([[0., 0., 0., 1.],
                    [0., 1., 1., 1.]])

        """
        assert len(batch) == 1
        xs, ys, spembs, extras = batch[0]

        # Keep the lengths as Python ints as well: filling ``labels`` below
        # then slices with plain ints instead of (possibly GPU) 0-d tensors.
        ilens_list = [x.shape[0] for x in xs]
        olens_list = [y.shape[0] for y in ys]
        ilens = torch.tensor(ilens_list, dtype=torch.long, device=device)
        olens = torch.tensor(olens_list, dtype=torch.long, device=device)

        xs = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(device)
        ys = pad_list([torch.from_numpy(y).float() for y in ys], 0).to(device)

        # labels[i, t] = 1 for t >= olen_i - 1, i.e. from the last real frame on.
        labels = ys.new_zeros(ys.size(0), ys.size(1))
        for i, olen in enumerate(olens_list):
            labels[i, olen - 1 :] = 1.0

        new_batch = {
            "xs": xs,
            "ilens": ilens,
            "ys": ys,
            "labels": labels,
            "olens": olens,
        }

        if spembs is not None:
            spembs = torch.from_numpy(np.array(spembs)).float()
            new_batch["spembs"] = spembs.to(device)

        if extras is not None:
            extras = pad_list([torch.from_numpy(extra).float() for extra in extras], 0)
            new_batch["extras"] = extras.to(device)

        return new_batch
def train(args):
    """Train E2E VC model.

    Builds the model, optimizer, data loaders and chainer trainer extensions
    from ``args``, runs training, and writes ``model.json``, snapshots,
    ``model.loss.best``, plots and logs under ``args.outdir``.

    Args:
        args (Namespace): Training arguments. ``args.spk_embed_dim`` and
            ``args.spc_dim`` are overwritten here; ``args.batch_sort_key`` is
            set to ``"input"`` when sortagrad is used, and ``args.batch_size``
            is multiplied by ``args.ngpu`` when more than one GPU is used.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info (index 1 of the stored shape is
    # used as the feature dimension)
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # get extra input and output dimension
    if args.use_speaker_embedding:
        args.spk_embed_dim = int(valid_json[utts[0]]["input"][1]["shape"][0])
    else:
        args.spk_embed_dim = None
    if args.use_second_target:
        # NOTE(review): reads input[1] like the speaker embedding above;
        # confirm where the second target lives when both are enabled.
        args.spc_dim = int(valid_json[utts[0]]["input"][1]["shape"][1])
    else:
        args.spc_dim = None

    # write model config
    os.makedirs(args.outdir, exist_ok=True)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # define model for training
    if args.enc_init is not None or args.dec_init is not None:
        model = load_trained_modules(idim, odim, args, TTSInterface)
    else:
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
    assert isinstance(model, TTSInterface)
    logging.info(model)
    reporter = model.reporter

    # freeze modules whose parameter names start with any given prefix
    if args.freeze_mods:
        for mod, param in model.named_parameters():
            if any(mod.startswith(key) for key in args.freeze_mods):
                logging.info("freezing %s" % mod)
                param.requires_grad = False
    for mod, param in model.named_parameters():
        if not param.requires_grad:
            logging.info("Frozen module %s" % mod)

    # check the use of multi-gpu
    if args.ngpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.ngpu)))
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)"
                % (args.batch_size, args.batch_size * args.ngpu)
            )
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # print trainable parameters
    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad)
            * 100.0
            / sum(p.numel() for p in model.parameters()),
        )
    )

    # Setup an optimizer
    if args.opt == "adam":
        optimizer = torch.optim.Adam(
            model.parameters(), args.lr, eps=args.eps, weight_decay=args.weight_decay
        )
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model, args.adim, args.transformer_warmup_steps, args.transformer_lr
        )
    elif args.opt == "lamb":
        from pytorch_lamb import Lamb

        optimizer = Lamb(
            model.parameters(), lr=args.lr, weight_decay=0.01, betas=(0.9, 0.999)
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # Attach chainer-style ``target`` / ``serialize`` so the torch optimizer can
    # be handed to the chainer trainer and snapshot machinery.
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # read json data (valid_json was already loaded above)
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    if use_sortagrad:
        args.batch_sort_key = "input"

    # make minibatch list (variable length)
    train_batchset = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=False,
        iaxis=0,
        oaxis=0,
    )
    valid_batchset = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        batch_sort_key=args.batch_sort_key,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        swap_io=False,
        iaxis=0,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(
        mode="vc",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )
    load_cv = LoadInputsAndTargets(
        mode="vc",
        use_speaker_embedding=args.use_speaker_embedding,
        use_second_target=args.use_second_target,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},
        keep_all_data_on_mem=args.keep_all_data_on_mem,
    )

    converter = CustomConverter()
    # Each dataset item is already a whole minibatch, hence batch_size=1 and
    # a collate_fn that unwraps the single element.
    train_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                train_batchset, lambda data: converter([load_tr(data)])
            ),
            batch_size=1,
            num_workers=args.num_iter_processes,
            shuffle=not use_sortagrad,
            collate_fn=lambda x: x[0],
        )
    }
    valid_iter = {
        "main": ChainerDataLoader(
            dataset=TransformDataset(
                valid_batchset, lambda data: converter([load_cv(data)])
            ),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: x[0],
            num_workers=args.num_iter_processes,
        )
    }

    # Set up a trainer
    updater = CustomUpdater(
        model, args.grad_clip, train_iter, optimizer, device, args.accum_grad
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # set intervals
    eval_interval = (args.eval_interval_epochs, "epoch")
    save_interval = (args.save_interval_epochs, "epoch")
    report_interval = (args.report_interval_iters, "iteration")

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(
        CustomEvaluator(model, valid_iter, reporter, device), trigger=eval_interval
    )

    # Save snapshot for each epoch
    trainer.extend(torch_snapshot(), trigger=save_interval)

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger(
            "validation/main/loss", trigger=eval_interval
        ),
    )

    # Save attention figure for each epoch
    if args.num_save_attention > 0:
        # Index 1 of an input shape is the feature dimension (see idim above)
        # and is the same for every utterance, so sort by index 0 to actually
        # order the utterances by length.
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            reverse=True,
        )
        trainer.extend(att_reporter, trigger=eval_interval)
    else:
        att_reporter = None

    # Make a plot for training and validation values
    if hasattr(model, "module"):
        base_plot_keys = model.module.base_plot_keys
    else:
        base_plot_keys = model.base_plot_keys
    plot_keys = []
    for key in base_plot_keys:
        plot_key = ["main/" + key, "validation/main/" + key]
        trainer.extend(
            extensions.PlotReport(plot_key, "epoch", file_name=key + ".png"),
            trigger=eval_interval,
        )
        plot_keys += plot_key
    trainer.extend(
        extensions.PlotReport(plot_keys, "epoch", file_name="all_loss.png"),
        trigger=eval_interval,
    )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport(trigger=report_interval))
    report_keys = ["epoch", "iteration", "elapsed_time"] + plot_keys
    trainer.extend(extensions.PrintReport(report_keys), trigger=report_interval)
    trainer.extend(extensions.ProgressBar(), trigger=report_interval)

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(TensorboardLogger(writer, att_reporter), trigger=report_interval)

    if use_sortagrad:
        # Hand over the data loader itself, not the {"main": loader} dict used
        # by the updater: ShufflingEnabler toggles shuffling on each item it
        # receives (NOTE(review): it calls start_shuffle() on each item).
        trainer.extend(
            ShufflingEnabler([train_iter["main"]]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def _plot_and_save(array, figname, figsize=(6, 4), dpi=150):
    """Plot a 1D, 2D or 4D array and save the figure to ``figname``.

    Args:
        array (ndarray): A 1D array is drawn as a line plot of probabilities
            (y axis fixed to [0, 1]); a 2D array as a heatmap; a 4D array as
            a grid of heatmaps with ``shape[0]`` rows and ``shape[1]`` columns.
        figname (str): Output image path; its parent directory is created if
            needed.
        figsize (tuple): Size of a single panel in inches.
        dpi (int): Resolution of the figure.

    Raises:
        NotImplementedError: If ``array`` is not 1D, 2D or 4D.

    """
    import matplotlib.pyplot as plt

    shape = array.shape
    if len(shape) == 1:
        plt.figure(figsize=figsize, dpi=dpi)
        plt.plot(array)
        plt.xlabel("Frame")
        plt.ylabel("Probability")
        plt.ylim([0, 1])
    elif len(shape) == 2:
        plt.figure(figsize=figsize, dpi=dpi)
        plt.imshow(array, aspect="auto")
        plt.xlabel("Input")
        plt.ylabel("Output")
    elif len(shape) == 4:
        plt.figure(figsize=(figsize[0] * shape[0], figsize[1] * shape[1]), dpi=dpi)
        for idx1, xs in enumerate(array):
            for idx2, x in enumerate(xs, 1):
                plt.subplot(shape[0], shape[1], idx1 * shape[1] + idx2)
                plt.imshow(x, aspect="auto")
                plt.xlabel("Input")
                plt.ylabel("Output")
    else:
        raise NotImplementedError("Support only from 1D to 4D array.")
    plt.tight_layout()
    # A bare filename has no directory part; os.makedirs("") would raise.
    figdir = os.path.dirname(figname)
    if figdir:
        os.makedirs(figdir, exist_ok=True)
    plt.savefig(figname)
    plt.close()


def _calculate_focus_rate(att_ws):
    """Return the focus rate of attention weights as a Python float.

    The focus rate is the mean, over output steps, of the largest attention
    weight in each row. For 4D weights (presumably layers x heads x out x in
    -- confirm), the highest per-head focus rate is returned.

    Args:
        att_ws (torch.Tensor or None): 2D or 4D attention weights.

    Returns:
        float: Focus rate; 1.0 when ``att_ws`` is None.

    Raises:
        ValueError: If ``att_ws`` is neither 2D nor 4D.

    """
    if att_ws is None:
        return 1.0
    elif len(att_ws.shape) == 2:
        return float(att_ws.max(dim=-1)[0].mean())
    elif len(att_ws.shape) == 4:
        return float(att_ws.max(dim=-1)[0].mean(dim=-1).max())
    else:
        raise ValueError("att_ws should be 2 or 4 dimensional tensor.")


def _convert_att_to_duration(att_ws):
    """Convert attention weights to per-input-frame durations.

    Each output step is assigned to the input index it attends to most; the
    duration of an input index is the number of output steps assigned to it.
    For 4D weights, the 2D map with the highest focus rate is used.

    Args:
        att_ws (torch.Tensor): 2D ``(n_out, n_in)`` or 4D attention weights.

    Returns:
        torch.Tensor: Float tensor of shape ``(n_in, 1)``.

    Raises:
        ValueError: If ``att_ws`` is neither 2D nor 4D.

    """
    if len(att_ws.shape) == 2:
        pass
    elif len(att_ws.shape) == 4:
        # Merge the two leading axes into a single list of 2D maps (same
        # order as concatenating them along dim 0), then keep the most
        # diagonal one.
        att_ws = att_ws.flatten(0, 1)
        diagonal_scores = att_ws.max(dim=-1)[0].mean(dim=-1)
        diagonal_head_idx = diagonal_scores.argmax()
        att_ws = att_ws[diagonal_head_idx]
    else:
        raise ValueError("att_ws should be 2 or 4 dimensional tensor.")
    # Count argmax hits per input index in one pass instead of one
    # comparison (and argmax) per input index.
    durations = torch.bincount(att_ws.argmax(-1), minlength=att_ws.shape[1])
    return durations.view(-1, 1).float()


def _sibling_output_path(out, name):
    """Return the output path of an auxiliary writer stored next to ``out``.

    ``"feats"`` in ``out`` is replaced by ``name``. If ``out`` does not
    contain ``"feats"``, ``"_" + name`` is appended instead; otherwise the
    auxiliary writer would open the very same ark/scp files as the feature
    writer and corrupt them.

    Args:
        out (str): Base output path of the feature writer (without extension).
        name (str): Name of the auxiliary output, e.g. ``"durations"``.

    Returns:
        str: Output path (without extension) for the auxiliary writer.

    """
    path = out.replace("feats", name)
    if path == out:
        path = "{}_{}".format(out, name)
        logging.warning(
            "'feats' not found in output path %s; writing %s to %s instead."
            % (out, name, path)
        )
    return path


@torch.no_grad()
def decode(args):
    """Decode with E2E VC model.

    Converts every utterance listed in ``args.json`` and writes the converted
    features to ``{args.out}.ark`` / ``{args.out}.scp``. Optionally writes
    attention-derived durations and focus rates to sibling ark/scp files, and
    saves probability / attention plots under the directory of ``args.out``.

    Args:
        args (Namespace): Decoding arguments.

    """
    import contextlib

    set_deterministic_pytorch(args)
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # show arguments
    for key in sorted(vars(args).keys()):
        logging.info("args: " + key + ": " + str(vars(args)[key]))

    # define model
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, TTSInterface)
    logging.info(model)

    # load trained model parameters
    logging.info("reading model parameters from " + args.model)
    torch_load(args.model, model)
    model.eval()

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # read json data
    with open(args.json, "rb") as f:
        js = json.load(f)["utts"]

    # ``outdir`` is "" when args.out has no directory component; plots are
    # then written relative to the working directory.
    outdir = os.path.dirname(args.out)
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="vc",
        load_output=False,
        sort_in_input_length=False,
        use_speaker_embedding=train_args.use_speaker_embedding,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},
    )

    num_utts = len(js)
    # ExitStack closes every writer that was opened, even if decoding fails.
    with contextlib.ExitStack() as stack:
        feat_writer = stack.enter_context(
            kaldiio.WriteHelper("ark,scp:{o}.ark,{o}.scp".format(o=args.out))
        )
        if args.save_durations:
            dur_out = _sibling_output_path(args.out, "durations")
            dur_writer = stack.enter_context(
                kaldiio.WriteHelper("ark,scp:{o}.ark,{o}.scp".format(o=dur_out))
            )
        if args.save_focus_rates:
            fr_out = _sibling_output_path(args.out, "focus_rates")
            fr_writer = stack.enter_context(
                kaldiio.WriteHelper("ark,scp:{o}.ark,{o}.scp".format(o=fr_out))
            )

        for idx, utt_id in enumerate(js.keys()):
            # setup inputs
            batch = [(utt_id, js[utt_id])]
            data = load_inputs_and_targets(batch)
            x = torch.FloatTensor(data[0][0]).to(device)
            spemb = None
            if train_args.use_speaker_embedding:
                spemb = torch.FloatTensor(data[1][0]).to(device)

            # decode and write
            start_time = time.time()
            outs, probs, att_ws = model.inference(x, args, spemb=spemb)
            logging.info(
                "inference speed = %.1f frames / sec."
                % (int(outs.size(0)) / (time.time() - start_time))
            )
            # NOTE(review): exact comparison with a float product; confirm it
            # matches how model.inference derives its maximum output length.
            if outs.size(0) == x.size(0) * args.maxlenratio:
                logging.warning("output length reaches maximum length (%s)." % utt_id)
            focus_rate = _calculate_focus_rate(att_ws)
            logging.info(
                "(%d/%d) %s (size: %d->%d, focus rate: %.3f)"
                % (idx + 1, num_utts, utt_id, x.size(0), outs.size(0), focus_rate)
            )

            feat_writer[utt_id] = outs.cpu().numpy()
            if args.save_durations:
                ds = _convert_att_to_duration(att_ws)
                dur_writer[utt_id] = ds.cpu().numpy()
            if args.save_focus_rates:
                fr_writer[utt_id] = np.array(focus_rate).reshape(1, 1)

            # plot and save prob and att_ws next to the output features
            if probs is not None:
                _plot_and_save(
                    probs.cpu().numpy(),
                    os.path.join(outdir, "probs", "%s_prob.png" % utt_id),
                )
            if att_ws is not None:
                _plot_and_save(
                    att_ws.cpu().numpy(),
                    os.path.join(outdir, "att_ws", "%s_att_ws.png" % utt_id),
                )