"""Text-to-spectrogram generator: text encoder plus invertible flow decoder."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from module import *
from commons import *


class Generator(nn.Module):
    """Couple a text :class:`Encoder` with a flow-based :class:`Decoder`.

    With ``gen=False`` the decoder maps the target ``y`` to a latent ``z`` and
    an alignment between text positions and frames is searched with
    ``maximum_path``.  With ``gen=True`` the predicted durations define the
    alignment and the decoder is run in reverse to synthesise ``y``.
    """

    def __init__(self, n_vocab, h_c, f_c, f_c_dp, out_c, k_s=3, k_s_dec=5,
                 heads=2, layers_enc=6):
        """Build the encoder and decoder.

        Args:
            n_vocab: Size of the token embedding table.
            h_c: Hidden channel count (shared by encoder and decoder).
            f_c: Channel count handed to the encoder ``FFN`` layers.
            f_c_dp: Channel count handed to the ``DurationPredictor``.
            out_c: Channel count of the encoder statistics / decoder input.
            k_s: Kernel size for the encoder FFN and duration predictor.
            k_s_dec: Kernel size for the decoder coupling layers.
            heads: Number of attention heads in the encoder.
            layers_enc: Number of encoder attention/FFN layers.
        """
        super().__init__()
        self.encoder = Encoder(n_vocab, out_c, h_c, f_c, f_c_dp,
                               heads=heads, layers=layers_enc, k_s=k_s)
        self.decoder = Decoder(in_c=out_c, hi_c=h_c, k_s=k_s_dec)

    def forward(self, x, x_lengths, y=None, y_lengths=None, gen=False,
                noise_scale=1., length_scale=1.):
        """Run the model in training (``gen=False``) or synthesis mode.

        Args:
            x: Token ids fed to the encoder embedding.
            x_lengths: Per-item valid lengths of ``x``.
            y: Target features, indexed as ``[b, d, t']``.  Required when
                ``gen`` is False; ignored (overwritten) when ``gen`` is True.
            y_lengths: Per-item valid lengths of ``y``.  Required when ``gen``
                is False; recomputed from the durations when ``gen`` is True.
            gen: Select synthesis mode.
            noise_scale: Multiplier on the sampled noise (synthesis only).
            length_scale: Multiplier on the predicted durations (synthesis
                only).

        Returns:
            Three tuples:
            ``(y_or_z, z_m, z_logs, logdet, z_mask)`` where the first element
            is the synthesised ``y`` in synthesis mode and the latent ``z``
            otherwise, ``(x_m, x_logs, x_mask)`` and ``(attn, logw, logw_)``.
        """
        x_m, x_logs, logw, x_mask = self.encoder(x, x_lengths)

        if gen:
            # Durations: exponentiate the predicted log-durations, round up.
            w = torch.exp(logw) * x_mask * length_scale
            w_ceil = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
            y_max_length = None

            y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
            z_mask = torch.unsqueeze(
                sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
            attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)

            attn = generate_path(
                w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
            z_m, z_logs, logw_ = self._expand_stats(attn, x_m, x_logs, x_mask)

            # Sample the latent around the expanded mean, then invert the flow.
            z = (z_m + torch.exp(z_logs) * torch.randn_like(z_m) * noise_scale) * z_mask
            y, logdet = self.decoder(z, z_mask, reverse=True)
            return (y, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)

        y_max_length = y.size(2)
        y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
        z_mask = torch.unsqueeze(
            sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)

        z, logdet = self.decoder(y, z_mask, reverse=False)

        # The alignment search is not differentiated through; only the
        # expansion below must carry gradients back to the encoder.
        with torch.no_grad():
            x_s_sq_r = torch.exp(-2 * x_logs)
            # Four terms of the diagonal-Gaussian log-density of every z frame
            # under every text position's (x_m, x_logs).
            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(-1)  # [b, t, 1]
            logp2 = torch.matmul(x_s_sq_r.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
            logp3 = torch.matmul((x_m * x_s_sq_r).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
            logp4 = torch.sum(-0.5 * (x_m ** 2) * x_s_sq_r, [1]).unsqueeze(-1)  # [b, t, 1]
            logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']

            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

        z_m, z_logs, logw_ = self._expand_stats(attn, x_m, x_logs, x_mask)
        return (z, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)

    @staticmethod
    def _expand_stats(attn, x_m, x_logs, x_mask):
        """Expand text-level statistics to frame level along ``attn``.

        Returns ``(z_m, z_logs, logw_)`` where ``logw_`` is the log of the
        number of frames attributed to each text position by ``attn``.
        """
        path = attn.squeeze(1).transpose(1, 2)
        z_m = torch.matmul(path, x_m.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        z_logs = torch.matmul(path, x_logs.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
        # 1e-8 keeps the log finite for positions with no frames.
        logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask
        return z_m, z_logs, logw_

    def preprocess(self, y, y_lengths, y_max_length):
        """Round lengths down to a multiple of 2 and trim ``y`` accordingly.

        The decoder folds pairs of frames into channels, so it needs an even
        number of frames.  ``y`` is only trimmed when ``y_max_length`` is
        given; ``y_lengths`` is always rounded down.
        """
        if y_max_length is not None:
            y_max_length = (y_max_length // 2) * 2
            y = y[:, :, :y_max_length]
        y_lengths = (y_lengths // 2) * 2
        return y, y_lengths, y_max_length


class Encoder(nn.Module):
    """Token encoder producing prior statistics and log-durations."""

    def __init__(self, n_vocab, out_c, h_c, f_c, f_c_dp, heads, layers, k_s,
                 p=0.1, mean_only=True):
        """Build the embedding, prenet, attention stack and projections.

        Args:
            n_vocab: Size of the token embedding table.
            out_c: Output channels of the mean / log-scale projections.
            h_c: Hidden channel count.
            f_c: Channel count handed to each ``FFN``.
            f_c_dp: Channel count handed to the ``DurationPredictor``.
            heads: Number of attention heads.
            layers: Number of attention/FFN layers.
            k_s: Kernel size for the FFN and duration predictor.
            p: Dropout probability.
            mean_only: When True no log-scale projection is created and
                ``forward`` returns zeros for ``x_logs``.
        """
        super().__init__()
        self.h_c = h_c
        self.mean_only = mean_only

        self.emb = nn.Embedding(n_vocab, h_c)
        nn.init.normal_(self.emb.weight, 0.0, h_c ** (-0.5))

        self.prenet = Prenet(in_c=h_c, hi_c=h_c, out_c=h_c, k_s=5)
        self.drop = nn.Dropout(p=p)

        self.atten_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        for _ in range(layers):
            # NOTE(review): attention dropout is fixed at 0.1 rather than `p`
            # -- confirm whether that is intentional.
            self.atten_layers.append(
                MultiheadAttention(h_c, h_c, heads, window_size=4,
                                   heads_share=True, p=0.1, block_length=None))
            # Two norms per layer: index 2*i follows attention, 2*i+1 the FFN.
            self.norm_layers.extend([Layernorm(h_c), Layernorm(h_c)])
            self.ffn_layers.append(FFN(h_c, f_c, k_s, p))

        self.proj_m = nn.Conv1d(h_c, out_c, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(h_c, out_c, 1)
        self.proj_w = DurationPredictor(h_c, f_c_dp, k_s, p=p)

    def forward(self, x, x_length):
        """Encode token ids.

        Args:
            x: Token ids, indexed as ``[b, t]``.
            x_length: Per-item valid lengths of ``x``.

        Returns:
            ``(x_m, x_logs, logw, x_mask)``: projected mean, projected
            log-scale (zeros when ``mean_only``), duration predictor output
            and the ``[b, 1, t]`` mask cast to the activation dtype.
        """
        # Scale by sqrt(h_c); math.sqrt avoids building a tensor on each call.
        x = self.emb(x) * math.sqrt(self.h_c)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_length, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)

        atten_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i, (attention, ffn) in enumerate(zip(self.atten_layers, self.ffn_layers)):
            x = x * x_mask
            y = self.drop(attention(x, atten_mask))
            x = self.norm_layers[2 * i](x + y)
            y = self.drop(ffn(x, x_mask))
            x = self.norm_layers[2 * i + 1](x + y)
        x = x * x_mask

        x_m = self.proj_m(x)
        if not self.mean_only:
            # Use the dedicated log-scale projection, not the mean projection.
            x_logs = self.proj_s(x)
        else:
            x_logs = torch.zeros_like(x_m)

        # Detached so the duration predictor does not back-propagate into
        # the encoder stack.
        logw = self.proj_w(x.detach(), x_mask)
        return x_m, x_logs, logw, x_mask


class Decoder(nn.Module):
    """Stack of flow blocks (ActNorm, InvConvNear, Couplinglayer)."""

    def __init__(self, in_c, hi_c, k_s, d_l=1, blocks=12, splits=4,):
        """Build ``blocks`` repetitions of the three-module flow step.

        Args:
            in_c: Input channel count; the flows operate on ``2 * in_c``
                channels because pairs of frames are folded into channels.
            hi_c: Hidden channel count of each coupling layer.
            k_s: Kernel size of each coupling layer.
            d_l: Forwarded to ``Couplinglayer`` as ``d_l``.
            blocks: Number of flow steps.
            splits: Forwarded to ``InvConvNear`` as ``splits``.
        """
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(blocks):
            self.flows.extend([
                ActNorm(in_c * 2),
                InvConvNear(splits=splits),
                Couplinglayer(in_c * 2, hi_c, k_s, d_l=d_l),
            ])

    def forward(self, x, x_mask=None, reverse=False):
        """Apply the flows forward, or in reversed order when ``reverse``.

        Args:
            x: Input indexed as ``[b, c, t]``; a trailing odd frame is dropped.
            x_mask: Optional ``[b, 1, t]`` mask.  When omitted an all-ones
                mask is created on ``x``'s device with ``x``'s dtype.
            reverse: Run the flows in reverse order with ``reverse=True``.

        Returns:
            ``(x, tot_logdet)``.  ``tot_logdet`` is the sum of the per-flow
            log-determinants in the forward direction and ``None`` in reverse.
        """
        if not reverse:
            flows = self.flows
            tot_logdet = 0
        else:
            flows = reversed(self.flows)
            tot_logdet = None

        b, c, t = x.shape
        t = t - t % 2  # folding below needs an even number of frames

        if x_mask is None:
            mask = x.new_ones(b, 1, t // 2)
        else:
            mask = x_mask[:, :, 1::2]
        # Fold each pair of neighbouring frames into the channel axis.
        x = x[:, :, :t].reshape(b, c, t // 2, 2).transpose(2, 3).contiguous().reshape(b, 2 * c, t // 2) * mask  # [b, 2c, t/2]

        for f in flows:
            x, logdet = f(x, mask, reverse=reverse)
            if not reverse:
                tot_logdet = tot_logdet + logdet

        if x_mask is None:
            mask = x.new_ones(b, 1, t)
        else:
            mask = x_mask[:, :, :t]
        # Undo the folding.
        x = x.reshape(b, c, 2, t // 2).transpose(2, 3).contiguous().reshape(b, c, t) * mask  # [b, c, t]
        return x, tot_logdet

    def skip(self):
        """Call ``skip()`` on every flow module."""
        for f in self.flows:
            f.skip()

    def ddi_init(self):
        """Call ``set_ddi()`` on every third flow module, starting at index 0.

        Given the construction order in ``__init__`` these are the ``ActNorm``
        modules.
        """
        for i, f in enumerate(self.flows):
            if i % 3 == 0:
                f.set_ddi()