import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from module import *
from commons import *
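# `module` and `commons` are expected to provide the building blocks used
# below: Prenet, MultiheadAttention, Layernorm, FFN, DurationPredictor,
# ActNorm, InvConvNear, Couplinglayer, sequence_mask, generate_path and
# maximum_path.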


class Generator(nn.Module):
    def __init__(self, n_vocab, h_c, f_c, f_c_dp, out_c, k_s=3, k_s_dec=5, heads=2, layers_enc=6):
        super().__init__()
        self.encoder = Encoder(n_vocab, out_c, h_c, f_c, f_c_dp, heads=heads, layers=layers_enc, k_s=k_s)
        self.decoder = Decoder(in_c=out_c, hi_c=h_c, k_s=k_s_dec)
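
    # forward() has two modes:
    #   gen=True  (inference): durations predicted by the encoder are ceiled and
    #     used to expand the prior statistics (x_m, x_logs) over time; a latent z
    #     is sampled from that prior and pushed through the inverted flow decoder.
    #   gen=False (training): the mel target y is encoded by the flow decoder, and
    #     the text-to-frame alignment is found with monotonic alignment search
    #     (maximum_path) under the encoder's prior.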
    def forward(self, x, x_lengths, y=None, y_lengths=None, gen=False, noise_scale=1., length_scale=1.):
        x_m, x_logs, logw, x_mask = self.encoder(x, x_lengths)
        if gen:
            w = torch.exp(logw) * x_mask * length_scale
            w_ceil = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
            y_max_length = None
            y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
            z_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
            attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)
            attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

            z_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2)  # [b, t', t] x [b, t, d] -> [b, d, t']
            z_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2)  # [b, t', t] x [b, t, d] -> [b, d, t']
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask

            z = (z_m + torch.exp(z_logs) * torch.randn_like(z_m) * noise_scale) * z_mask
            y, logdet = self.decoder(z, z_mask, reverse=True)
            return (y, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)
        else:
            y_max_length = y.size(2)
            y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
            z_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
            attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)
            z, logdet = self.decoder(y, z_mask, reverse=False)

            # Monotonic alignment search: find the most likely alignment between
            # the encoder prior and the latent frames z (no gradients needed).
            with torch.no_grad():
                x_s_sq_r = torch.exp(-2 * x_logs)
                logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(-1)  # [b, t, 1]
                logp2 = torch.matmul(x_s_sq_r.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp3 = torch.matmul((x_m * x_s_sq_r).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp4 = torch.sum(-0.5 * (x_m ** 2) * x_s_sq_r, [1]).unsqueeze(-1)  # [b, t, 1]
                logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
                attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

            z_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2)  # [b, t', t] x [b, t, d] -> [b, d, t']
            z_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2)  # [b, t', t] x [b, t, d] -> [b, d, t']
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask
            return (z, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)
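
    # preprocess() crops mel targets/lengths to an even number of frames so that
    # the decoder's squeeze-by-2 along the time axis is always well defined.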
    def preprocess(self, y, y_lengths, y_max_length):
        if y_max_length is not None:
            y_max_length = (y_max_length // 2) * 2
            y = y[:, :, :y_max_length]
        y_lengths = (y_lengths // 2) * 2
        return y, y_lengths, y_max_length
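

# Text encoder: token embedding + convolutional prenet + windowed relative-
# position self-attention blocks; projects to the prior mean (proj_m),
# optionally the prior log-std (proj_s), and log-durations (proj_w).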
class Encoder(nn.Module):
    def __init__(self, n_vocab, out_c, h_c, f_c, f_c_dp, heads, layers, k_s, p=0.1, mean_only=True):
        super().__init__()
        self.h_c = h_c
        self.mean_only = mean_only
        self.emb = nn.Embedding(n_vocab, h_c)
        nn.init.normal_(self.emb.weight, 0.0, h_c ** (-0.5))
        self.prenet = Prenet(in_c=h_c, hi_c=h_c, out_c=h_c, k_s=5)
        self.drop = nn.Dropout(p=p)
        self.atten_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        for i in range(layers):
            self.atten_layers.append(MultiheadAttention(h_c, h_c, heads, window_size=4, heads_share=True, p=p, block_length=None))
            self.norm_layers.extend([Layernorm(h_c), Layernorm(h_c)])
            self.ffn_layers.append(FFN(h_c, f_c, k_s, p))
        self.proj_m = nn.Conv1d(h_c, out_c, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(h_c, out_c, 1)
        self.proj_w = DurationPredictor(h_c, f_c_dp, k_s, p=p)
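
    # Note: forward() feeds x.detach() to the duration predictor, so the
    # duration loss does not backpropagate into the rest of the encoder.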
    def forward(self, x, x_length):
        x = self.emb(x) * math.sqrt(self.h_c)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_length, x.size(2)), 1).to(x.dtype)
        x = self.prenet(x, x_mask)
        atten_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(len(self.atten_layers)):
            x = x * x_mask
            y = self.drop(self.atten_layers[i](x, atten_mask))
            x = self.norm_layers[2 * i](x + y)
            y = self.drop(self.ffn_layers[i](x, x_mask))
            x = self.norm_layers[2 * i + 1](x + y)
        x = x * x_mask
        x_m = self.proj_m(x)
        if not self.mean_only:
            x_logs = self.proj_s(x)
        else:
            x_logs = torch.zeros_like(x_m)
        logw = self.proj_w(x.detach(), x_mask)
        return x_m, x_logs, logw, x_mask
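

# Flow-based decoder: `blocks` steps of ActNorm -> invertible 1x1 convolution
# (InvConvNear) -> affine coupling, operating on the time-squeezed mel
# representation [b, 2c, t/2].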
class Decoder(nn.Module):
    def __init__(self, in_c, hi_c, k_s, d_l=1, blocks=12, splits=4):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(blocks):
            self.flows.extend([ActNorm(in_c * 2), InvConvNear(splits=splits), Couplinglayer(in_c * 2, hi_c, k_s, d_l=d_l)])
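
    # reverse=False encodes (squeezed) mels into latents z and accumulates the
    # flow log-determinant; reverse=True runs the flows backwards for synthesis.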
    def forward(self, x, x_mask=None, reverse=False):
        if not reverse:
            flows = self.flows
            tot_logdet = 0
        else:
            flows = reversed(self.flows)
            tot_logdet = None

        b, c, t = x.shape
        t = t - t % 2
        if x_mask is None:
            mask = torch.ones(b, 1, t // 2, dtype=x.dtype, device=x.device)
        else:
            mask = x_mask[:, :, 1::2]
        x = x[:, :, :t].reshape(b, c, t // 2, 2).transpose(2, 3).contiguous().reshape(b, 2 * c, t // 2) * mask  # [b, 2c, t/2]

        for f in flows:
            x, logdet = f(x, mask, reverse=reverse)
            if not reverse:
                tot_logdet = tot_logdet + logdet

        if x_mask is None:
            mask = torch.ones(b, 1, t, dtype=x.dtype, device=x.device)
        else:
            mask = x_mask[:, :, :t]
        x = x.reshape(b, c, 2, t // 2).transpose(2, 3).contiguous().reshape(b, c, t) * mask  # [b, c, t]
        return x, tot_logdet
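
    # Flows are stored as repeating triples [ActNorm, InvConvNear, Couplinglayer],
    # so indices with i % 3 == 0 are the ActNorm layers; set_ddi() presumably arms
    # their data-dependent initialization.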
    def skip(self):
        for f in self.flows:
            f.skip()

    def ddi_init(self):
        for i, f in enumerate(self.flows):
            if i % 3 == 0:
                f.set_ddi()
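

if __name__ == "__main__":
    # Minimal smoke-test sketch, assuming `module`/`commons` provide the
    # Glow-TTS-style layers used above. The hyperparameters below (vocab size,
    # channel widths, mel dimension, lengths) are illustrative placeholders only.
    n_vocab, h_c, f_c, f_c_dp, n_mel = 40, 192, 768, 256, 80
    model = Generator(n_vocab, h_c, f_c, f_c_dp, out_c=n_mel)

    x = torch.randint(1, n_vocab, (2, 25))  # [b, t] token ids
    x_lengths = torch.tensor([25, 20])
    y = torch.randn(2, n_mel, 100)          # [b, n_mel, t'] target mels
    y_lengths = torch.tensor([100, 90])

    # Training mode: returns latents, prior statistics and the MAS alignment.
    (z, z_m, z_logs, logdet, z_mask), _, (attn, logw, logw_) = model(x, x_lengths, y, y_lengths, gen=False)

    # Inference mode: durations come from the duration predictor.
    with torch.no_grad():
        (y_gen, *_), _, _ = model(x, x_lengths, gen=True, noise_scale=0.667, length_scale=1.0)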