# Glow-HiFi-TTS / model.py
# (Hugging Face file-viewer header, kept as comments so the module parses.)
# marigold334's picture
# Upload 10 files
# 41989ff
# raw
# history blame
# No virus
# 7.08 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from module import *
from commons import *
import math
class Generator(nn.Module):
    """Flow-based generator that pairs a text `Encoder` with a `Decoder`.

    The encoder produces per-token statistics (`x_m`, `x_logs`), a
    log-duration prediction (`logw`) and a token mask. In training mode the
    decoder maps `y` to a latent `z` and an alignment between tokens and
    frames is obtained from `maximum_path`. In generation mode the alignment
    is built from the predicted durations with `generate_path`, a latent is
    sampled, and the decoder is run in reverse.
    """

    def __init__(self, n_vocab, h_c, f_c, f_c_dp, out_c, k_s=3, k_s_dec=5, heads=2, layers_enc=6):
        """Build the encoder/decoder pair.

        Args:
            n_vocab: Vocabulary size, forwarded to `Encoder`.
            h_c: Hidden channel count shared by encoder and decoder.
            f_c: Forwarded to `Encoder` as its `f_c` argument.
            f_c_dp: Forwarded to `Encoder` as its `f_c_dp` argument.
            out_c: Encoder output channels; also the decoder's `in_c`.
            k_s: Kernel size forwarded to `Encoder`.
            k_s_dec: Kernel size forwarded to `Decoder`.
            heads: Attention head count forwarded to `Encoder`.
            layers_enc: Number of encoder layers.
        """
        super().__init__()
        self.encoder = Encoder(
            n_vocab, out_c, h_c, f_c, f_c_dp,
            heads=heads, layers=layers_enc, k_s=k_s,
        )
        self.decoder = Decoder(in_c=out_c, hi_c=h_c, k_s=k_s_dec)

    def forward(self, x, x_lengths, y=None, y_lengths=None, gen=False, noise_scale=1., length_scale=1.):
        """Run the model in training (`gen=False`) or generation (`gen=True`) mode.

        Args:
            x: Token ids, passed straight to the encoder.
            x_lengths: Per-item token counts, passed straight to the encoder.
            y: Target features, required when `gen` is false. Presumably
                shaped [b, d, t'] (the code reads the length from dim 2).
            y_lengths: Per-item frame counts; recomputed from the predicted
                durations when `gen` is true.
            gen: Select generation mode.
            noise_scale: Multiplier on the sampled noise (generation only).
            length_scale: Multiplier on the predicted durations (generation only).

        Returns:
            Three tuples:
            `(out, z_m, z_logs, logdet, z_mask)` where `out` is the decoder
            output (`y` when generating, `z` when training),
            `(x_m, x_logs, x_mask)` from the encoder, and
            `(attn, logw, logw_)` with the alignment, the predicted
            log-durations and the log-durations implied by the alignment.
        """
        x_m, x_logs, logw, x_mask = self.encoder(x, x_lengths)

        if gen:
            # Durations rounded up to whole frames define the output lengths.
            durations = torch.ceil(torch.exp(logw) * x_mask * length_scale)
            y_lengths = torch.clamp_min(torch.sum(durations, [1, 2]), 1).long()
            _, y_lengths, y_max_length = self.preprocess(y, y_lengths, None)

            z_mask, attn_mask = self._frame_masks(x_mask, y_lengths, y_max_length)
            attn = generate_path(durations.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
            z_m, z_logs, logw_ = self._expand_by_alignment(attn, x_m, x_logs, x_mask)

            noise = torch.randn_like(z_m)
            z = (z_m + torch.exp(z_logs) * noise * noise_scale) * z_mask
            y, logdet = self.decoder(z, z_mask, reverse=True)
            return (y, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)

        y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y.size(2))
        z_mask, attn_mask = self._frame_masks(x_mask, y_lengths, y_max_length)
        z, logdet = self.decoder(y, z_mask, reverse=False)

        # The alignment search is excluded from the autograd graph.
        with torch.no_grad():
            inv_var = torch.exp(-2 * x_logs)
            const_term = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(-1)  # [b, t, 1]
            quad_term = torch.matmul(inv_var.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, t']
            cross_term = torch.matmul((x_m * inv_var).transpose(1, 2), z)  # [b, t, t']
            mean_term = torch.sum(-0.5 * (x_m ** 2) * inv_var, [1]).unsqueeze(-1)  # [b, t, 1]
            logp = const_term + quad_term + cross_term + mean_term  # [b, t, t']
            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

        z_m, z_logs, logw_ = self._expand_by_alignment(attn, x_m, x_logs, x_mask)
        return (z, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)

    @staticmethod
    def _frame_masks(x_mask, y_lengths, y_max_length):
        """Return the frame mask and the token-by-frame mask derived from it."""
        z_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)
        return z_mask, attn_mask

    @staticmethod
    def _expand_by_alignment(attn, x_m, x_logs, x_mask):
        """Spread token statistics over frames and derive durations from `attn`."""
        frames_by_tokens = attn.squeeze(1).transpose(1, 2)  # [b, t', t]
        # [b, t', t] x [b, t, d] -> [b, d, t']
        z_m = torch.matmul(frames_by_tokens, x_m.transpose(1, 2)).transpose(1, 2)
        z_logs = torch.matmul(frames_by_tokens, x_logs.transpose(1, 2)).transpose(1, 2)
        logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask
        return z_m, z_logs, logw_

    def preprocess(self, y, y_lengths, y_max_length):
        """Round lengths down to even values and trim `y` to match.

        `y` is only sliced when `y_max_length` is given; `y_lengths` is
        always rounded down to a multiple of two.
        """
        if y_max_length is not None:
            y_max_length = 2 * (y_max_length // 2)
            y = y[:, :, :y_max_length]
        y_lengths = (y_lengths // 2) * 2
        return y, y_lengths, y_max_length
class Encoder(nn.Module):
    """Text encoder producing per-token statistics and log-durations.

    Tokens are embedded, passed through a `Prenet`, then through `layers`
    blocks of (attention -> layer norm -> FFN -> layer norm) with residual
    connections. The result is projected to a mean (`proj_m`), optionally a
    log-scale (`proj_s`, only when `mean_only` is false) and a duration
    prediction (`proj_w`).
    """

    def __init__(self, n_vocab, out_c, h_c, f_c, f_c_dp, heads, layers, k_s, p=0.1, mean_only=True):
        """Create the sub-modules.

        Args:
            n_vocab: Size of the embedding table.
            out_c: Channels of the projected mean / log-scale.
            h_c: Hidden channel count used throughout the encoder.
            f_c: Forwarded to each `FFN` as its second argument.
            f_c_dp: Forwarded to `DurationPredictor` as its second argument.
            heads: Number of attention heads.
            layers: Number of attention/FFN blocks.
            k_s: Kernel size for the `FFN` and `DurationPredictor` modules.
            p: Dropout probability for `self.drop`, the FFNs and the
                duration predictor.
            mean_only: When true, no log-scale projection is created and
                `forward` returns zeros for `x_logs`.
        """
        super().__init__()
        self.h_c = h_c
        self.mean_only = mean_only

        self.emb = nn.Embedding(n_vocab, h_c)
        nn.init.normal_(self.emb.weight, 0.0, h_c ** (-0.5))
        self.prenet = Prenet(in_c=h_c, hi_c=h_c, out_c=h_c, k_s=5)
        self.drop = nn.Dropout(p=p)

        self.atten_layers = nn.ModuleList()
        # Flat list holding two norms per block: index 2*i follows the
        # attention of block i, index 2*i + 1 follows its FFN.
        self.norm_layers = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        for _ in range(layers):
            # NOTE(review): the attention dropout is fixed at 0.1 rather than
            # following `p` - confirm whether that is intentional.
            self.atten_layers.append(
                MultiheadAttention(h_c, h_c, heads, window_size=4, heads_share=True, p=0.1, block_length=None)
            )
            self.norm_layers.extend([Layernorm(h_c), Layernorm(h_c)])
            self.ffn_layers.append(FFN(h_c, f_c, k_s, p))

        self.proj_m = nn.Conv1d(h_c, out_c, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(h_c, out_c, 1)
        self.proj_w = DurationPredictor(h_c, f_c_dp, k_s, p=p)

    def forward(self, x, x_length):
        """Encode token ids.

        Args:
            x: Token ids; the embedding output is treated as [b, t, h].
            x_length: Per-item token counts used to build the mask.

        Returns:
            `(x_m, x_logs, logw, x_mask)`: projected mean, log-scale (zeros
            when `mean_only`), the duration predictor output, and the token
            mask shaped [b, 1, t].
        """
        # A plain Python scalar avoids building a tensor on every call.
        x = self.emb(x) * math.sqrt(self.h_c)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_length, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        atten_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i, (attention, ffn) in enumerate(zip(self.atten_layers, self.ffn_layers)):
            x = x * x_mask
            y = self.drop(attention(x, atten_mask))
            x = self.norm_layers[2 * i](x + y)
            y = self.drop(ffn(x, x_mask))
            x = self.norm_layers[2 * i + 1](x + y)
        x = x * x_mask

        x_m = self.proj_m(x)
        if self.mean_only:
            x_logs = torch.zeros_like(x_m)
        else:
            # Use the dedicated log-scale projection; reusing `proj_m` here
            # would make the log-scale identical to the mean and leave
            # `proj_s` untrained.
            x_logs = self.proj_s(x)

        # The duration predictor sees a detached copy, so its loss does not
        # back-propagate into the encoder stack.
        logw = self.proj_w(x.detach(), x_mask)
        return x_m, x_logs, logw, x_mask
class Decoder(nn.Module):
    """Stack of flow steps applied to frame pairs folded into channels.

    Each of the `blocks` steps appends three modules to `self.flows`, in
    order: `ActNorm`, `InvConvNear`, `Couplinglayer`. The input is folded
    from [b, c, t] to [b, 2c, t/2] before the flows and unfolded afterwards.
    """

    def __init__(self, in_c, hi_c, k_s, d_l=1, blocks=12, splits=4):
        """Create `blocks` repetitions of ActNorm / InvConvNear / Couplinglayer.

        Args:
            in_c: Channels of the un-folded input; flows operate on `2 * in_c`.
            hi_c: Hidden channels forwarded to each `Couplinglayer`.
            k_s: Kernel size forwarded to each `Couplinglayer`.
            d_l: Forwarded to each `Couplinglayer` as `d_l`.
            blocks: Number of flow steps.
            splits: Forwarded to each `InvConvNear` as `splits`.
        """
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(blocks):
            self.flows.extend([
                ActNorm(in_c * 2),
                InvConvNear(splits=splits),
                Couplinglayer(in_c * 2, hi_c, k_s, d_l=d_l),
            ])

    def forward(self, x, x_mask=None, reverse=False):
        """Apply the flows in order, or in reverse order when `reverse` is true.

        Args:
            x: Input of shape [b, c, t]; a trailing odd frame is dropped.
            x_mask: Optional mask indexed as [b, 1, t]. When omitted, an
                all-ones mask is used.
            reverse: Run the flows back to front and skip log-det accumulation.

        Returns:
            `(x, tot_logdet)` with `x` shaped [b, c, t - t % 2]. `tot_logdet`
            is the sum of the per-flow log-dets (starting from 0) in the
            forward direction and `None` in the reverse direction.
        """
        b, c, t = x.shape
        t = t - t % 2  # frames are folded in pairs, so the length must be even

        if x_mask is None:
            # Build the masks next to `x`; default-device float32 tensors
            # would raise a device mismatch for non-CPU inputs.
            folded_mask = torch.ones(b, 1, t // 2, dtype=x.dtype, device=x.device)
            full_mask = torch.ones(b, 1, t, dtype=x.dtype, device=x.device)
        else:
            folded_mask = x_mask[:, :, 1::2]
            full_mask = x_mask[:, :, :t]

        if reverse:
            flows = reversed(self.flows)
            tot_logdet = None
        else:
            flows = self.flows
            tot_logdet = 0

        # Fold adjacent frame pairs into the channel axis: [b, c, t] -> [b, 2c, t/2].
        x = x[:, :, :t].reshape(b, c, t // 2, 2).transpose(2, 3).contiguous()
        x = x.reshape(b, 2 * c, t // 2) * folded_mask

        for flow in flows:
            x, logdet = flow(x, folded_mask, reverse=reverse)
            if not reverse:
                tot_logdet = tot_logdet + logdet

        # Undo the fold: [b, 2c, t/2] -> [b, c, t].
        x = x.reshape(b, c, 2, t // 2).transpose(2, 3).contiguous()
        x = x.reshape(b, c, t) * full_mask
        return x, tot_logdet

    def skip(self):
        """Call `skip()` on every flow module."""
        for flow in self.flows:
            flow.skip()

    def ddi_init(self):
        """Call `set_ddi()` on every third flow, starting with the first.

        Given the construction order in `__init__`, those are the `ActNorm`
        modules.
        """
        for i, flow in enumerate(self.flows):
            if i % 3 == 0:
                flow.set_ddi()