import torch
import torch.nn as nn
import torch.nn.functional as F


######################################### encoder ##############################################

class Layernorm(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, channels))
        self.beta = nn.Parameter(torch.zeros(1, channels))

    def forward(self, x):
        # normalization over the channel dimension
        m = torch.mean(x, dim=1, keepdim=True)
        v = torch.mean((x - m) ** 2, dim=1, keepdim=True)
        x = (x - m) * torch.rsqrt(v + 1e-4)
        n = len(x.shape)
        shape = [1, -1] + [1] * (n - 2)
        x = x * self.gamma.reshape(*shape) + self.beta.reshape(*shape)
        return x


class Prenet(nn.Module):
    def __init__(self, in_c, hi_c, out_c, k_s=5, layers=3, p=0.05):
        super().__init__()
        self.crn = nn.ModuleList()
        for i in range(layers):
            c_in = in_c if i == 0 else hi_c
            self.crn.extend([nn.Conv1d(c_in, hi_c, k_s, padding=k_s // 2),
                             Layernorm(hi_c),
                             nn.ReLU(),
                             nn.Dropout(p=p)])
        self.proj = nn.Conv1d(hi_c, out_c, 1)
        # zero-init the projection so the prenet starts as an identity residual
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, start, x_mask=1):
        x = start
        for layer in self.crn:
            x = layer(x)              # [b, c, t]
            x = x * x_mask
        x = self.proj(x) + start      # [b, c, t]
        end = x * x_mask
        return end                    # [b, c, t]


class MultiheadAttention(nn.Module):
    def __init__(self, c, out_c, heads, window_size=4, heads_share=True, p=0.1, block_length=None):
        super().__init__()
        self.k = c // heads                 # channels per head
        self.window_size = window_size
        self.proj_q = nn.Conv1d(c, c, 1)
        self.proj_k = nn.Conv1d(c, c, 1)
        self.proj_v = nn.Conv1d(c, c, 1)
        nn.init.xavier_uniform_(self.proj_q.weight)
        nn.init.xavier_uniform_(self.proj_k.weight)
        nn.init.xavier_uniform_(self.proj_v.weight)
        n_heads_rel = 1 if heads_share else heads
        self.d_k = self.k ** (-0.5)         # 1/sqrt(d_k) scaling factor
        self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k) * self.d_k)
        self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k) * self.d_k)
        self.conv_o = nn.Conv1d(c, out_c, 1)
        self.drop = nn.Dropout(p=p)

    def forward(self, x, attn_mask=None):
        query, key, value = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        b, c, t = query.shape
        h, k = c // self.k, self.k
        query = query.reshape(b, h, k, t)
        key = key.reshape(b, h, k, t)
        value = value.reshape(b, h, k, t)

        matrix = self.get_relative_matrix(self.emb_rel_k, t)
        rel_logit = torch.matmul(matrix.unsqueeze(0), query)    # [1,1,2t-1,k] x [b,h,k,t] = [b,h,2t-1,t]
        abs_logit = self.rel_to_abs(rel_logit.transpose(2, 3))  # [b,h,t,t]
        local_score = abs_logit * self.d_k
        score = torch.matmul(query.transpose(2, 3), key) * self.d_k + local_score
        if attn_mask is not None:
            score = score.masked_fill(attn_mask == 0, -1e4)
        align = F.softmax(score, dim=-1)
        atten = self.drop(align)
        self.atten = atten                                      # kept for inspection/visualization

        matrix = self.get_relative_matrix(self.emb_rel_v, t).transpose(1, 2)  # [1,k,2t-1]
        weight = self.abs_to_rel(atten).transpose(2, 3)                       # [b,h,2t-1,t]
        # attend over the key axis, then add the relative-position value term
        output = torch.matmul(value, atten.transpose(2, 3)) + torch.matmul(matrix.unsqueeze(0), weight)  # [b,h,k,t]
        x = self.conv_o(output.contiguous().reshape(b, c, t))
        return x

    def get_relative_matrix(self, emb_rel_k, t):
        # pad (if needed) and slice the relative embeddings to 2t-1 positions
        s = self.window_size
        pad_size = max(t - s - 1, 0)
        start = max(s + 1 - t, 0)
        emb_rel_k = F.pad(emb_rel_k, (0, 0, pad_size, pad_size))
        return emb_rel_k[:, start:start + 2 * t - 1]

    def rel_to_abs(self, x):
        # [b,h,t,2t-1] relative logits -> [b,h,t,t] absolute logits
        b, h, t, _ = x.shape
        x = F.pad(x, (0, 1)).reshape(b, h, 2 * t * t)
        x = F.pad(x, (0, t - 1)).reshape(b, h, t + 1, 2 * t - 1)[:, :, :t, t - 1:]
        return x

    def abs_to_rel(self, x):
        # [b,h,t,t] absolute weights -> [b,h,t,2t-1] relative weights
        b, h, t, _ = x.shape
        x = F.pad(x, (0, t - 1)).reshape(b, h, 2 * t * t - t)
        x = F.pad(x, (t, 0)).reshape(b, h, t, 2 * t)[:, :, :, 1:]
        return x
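
# Illustrative sanity check, not part of the original module: with window_size=4
# and t=6 frames, the relative-position helpers should map [b,h,t,2t-1] logits to
# absolute [b,h,t,t] scores and back, and the attention block should preserve
# [b,c,t]. The sizes below are arbitrary assumptions chosen for the check.
def _check_relative_attention_shapes():
    attn = MultiheadAttention(c=64, out_c=64, heads=4, window_size=4)
    x = torch.randn(2, 64, 6)                                   # [b, c, t]
    assert attn(x, attn_mask=torch.ones(2, 1, 6, 6)).shape == (2, 64, 6)
    rel = torch.randn(2, 4, 6, 11)                              # [b, h, t, 2t-1]
    assert attn.rel_to_abs(rel).shape == (2, 4, 6, 6)
    assert attn.abs_to_rel(attn.rel_to_abs(rel)).shape == (2, 4, 6, 11)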

class FFN(nn.Module):
    def __init__(self, h_c, f_c, k_s, p=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(h_c, f_c, k_s, padding=k_s // 2)
        self.conv2 = nn.Conv1d(f_c, h_c, k_s, padding=k_s // 2)
        self.drop = nn.Dropout(p=p)

    def forward(self, x, x_mask=1):
        x = self.conv2(self.drop(F.relu(self.conv1(x * x_mask))) * x_mask)
        return x * x_mask


class DurationPredictor(nn.Module):
    def __init__(self, in_c, f_c, k_s, p=0.1):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv1d(in_c, f_c, k_s, padding=k_s // 2),
                                    nn.ReLU(),
                                    Layernorm(f_c),
                                    nn.Dropout(p=p))
        self.block2 = nn.Sequential(nn.Conv1d(f_c, f_c, k_s, padding=k_s // 2),
                                    nn.ReLU(),
                                    Layernorm(f_c),
                                    nn.Dropout(p=p))
        self.proj = nn.Conv1d(f_c, 1, 1)

    def forward(self, x, x_mask):
        x = self.block1(x * x_mask)
        x = self.block2(x * x_mask)
        x = self.proj(x * x_mask)
        return x * x_mask


######################################### decoder ##############################################

# TorchScript: static tensor typing and an optimized, ahead-of-time compiled graph
# speed up this fused gate compared with running it eagerly.
@torch.jit.script
def fuse_tan_sig_add(x: torch.Tensor, mid: int) -> torch.Tensor:
    a, b = x[:, :mid, :], x[:, mid:, :]
    return torch.sigmoid(a) * torch.tanh(b)


class WN(nn.Module):
    # non-causal WaveNet-style stack with a single dilation rate
    def __init__(self, hi_c, k_s, d_l=1, layers=3, p=0.05):
        super().__init__()
        self.hi_c = hi_c
        self.resblocks = nn.ModuleList()
        self.skipblocks = nn.ModuleList()
        self.drop = nn.Dropout(p=p)
        for i in range(layers):
            res_layer = nn.Conv1d(hi_c, 2 * hi_c, k_s, dilation=d_l,
                                  padding=d_l * (k_s - 1) // 2)
            res_layer = nn.utils.weight_norm(res_layer, name='weight')
            self.resblocks.append(res_layer)
            if i == layers - 1:
                skip_layer = nn.Conv1d(hi_c, hi_c, 1)        # last layer: skip path only
            else:
                skip_layer = nn.Conv1d(hi_c, 2 * hi_c, 1)    # residual + skip halves
            skip_layer = nn.utils.weight_norm(skip_layer, name='weight')
            self.skipblocks.append(skip_layer)

    def forward(self, x, x_mask=1):
        mid = self.hi_c
        end = torch.zeros_like(x)
        for i in range(len(self.resblocks)):
            x = self.drop(self.resblocks[i](x))   # [b, 2c, t]
            x = fuse_tan_sig_add(x, mid)          # [b, c, t]
            y = self.skipblocks[i](x)
            if i == len(self.resblocks) - 1:
                end = end + y                     # last layer: skip connection only
            else:
                x = (x + y[:, :mid, :]) * x_mask  # residual half
                end = end + y[:, mid:, :]         # skip half
        return end * x_mask

    def skip(self):
        for layer1, layer2 in zip(self.resblocks, self.skipblocks):
            nn.utils.remove_weight_norm(layer1)
            nn.utils.remove_weight_norm(layer2)
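
# Illustrative check, not part of the original module: the scripted gate above
# should match an eager sigmoid/tanh gate, and WN should preserve [b, hi_c, t].
# hi_c=4, k_s=3 and the tensor sizes are arbitrary assumptions for the check.
def _check_wn_gate():
    x = torch.randn(2, 8, 5)                        # [b, 2*hi_c, t] with hi_c = 4
    fused = fuse_tan_sig_add(x, 4)
    eager = torch.sigmoid(x[:, :4, :]) * torch.tanh(x[:, 4:, :])
    assert torch.allclose(fused, eager)
    wn = WN(hi_c=4, k_s=3)
    y = wn(torch.randn(2, 4, 5), x_mask=torch.ones(2, 1, 5))
    assert y.shape == (2, 4, 5)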

class Couplinglayer(nn.Module):
    def __init__(self, in_c, hi_c, k_s, d_l=1):
        super().__init__()
        s_proj = nn.Conv1d(in_c // 2, hi_c, 1)
        self.start = nn.utils.weight_norm(s_proj, name='weight')
        # Initializing the last layer to zero makes the affine coupling layer
        # do nothing at first, which helps to stabilize training (from the Glow paper).
        self.end = nn.Conv1d(hi_c, in_c, 1)
        self.end.weight.data.zero_()
        self.end.bias.data.zero_()
        self.wn = WN(hi_c, k_s, d_l)

    # affine coupling: y = exp(logs) * x + t on half of the channels
    def forward(self, x, x_mask=None, reverse=False):
        if x_mask is None:
            x_mask = 1
        mid = x.shape[1] // 2                      # split the channels in half
        x_0, x_1 = x[:, :mid, :], x[:, mid:, :]
        z_1 = self.end(self.wn(self.start(x_1) * x_mask, x_mask))
        logs, t = z_1[:, mid:, :], z_1[:, :mid, :]
        if reverse:
            x_0 = torch.exp(-logs) * (x_0 - t) * x_mask
            logdet = None
        else:
            x_0 = (torch.exp(logs) * x_0 + t) * x_mask
            logdet = torch.sum(logs * x_mask, [1, 2])   # sum(log(s))
        z = torch.cat([x_0, x_1], dim=1)
        return z, logdet

    def skip(self):
        self.wn.skip()


class InvConvNear(nn.Module):
    def __init__(self, splits=4):
        super().__init__()
        self.splits = splits
        # orthonormal initialization; flip a row if needed so det(W) > 0
        w_init = torch.linalg.qr(torch.randn(splits, splits))[0]
        if torch.det(w_init) < 0:
            w_init[0, :] = -w_init[0, :]
        self.weight = nn.Parameter(w_init)

    def forward(self, x, x_mask=None, reverse=False):
        b, c, t = x.shape
        if x_mask is None:
            x_mask = 1
            x_len = torch.ones(b, dtype=x.dtype, device=x.device) * t   # [b]
        else:
            x_len = torch.sum(x_mask, [1, 2])
        s = self.splits
        x = x.reshape(b, 2, c // s, s // 2, t)   # split channels into 2 groups
        x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(b, s, c // s, t)

        if reverse:
            if hasattr(self, "weight_inv"):
                weight = self.weight_inv
            else:
                weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
            logdet = None
        else:
            weight = self.weight
            # (c/s) * t * log|det(W)|; the matrix is small, so no LU decomposition is needed
            logdet = torch.logdet(weight) * (c // s) * x_len

        weight = weight.unsqueeze(-1).unsqueeze(-1)
        z = F.conv2d(x, weight)   # z[b, :, i, j] = W @ x[b, :, i, j] over the (c//s, t) grid
        z = z.reshape(b, 2, s // 2, c // s, t).permute(0, 1, 3, 2, 4).contiguous().reshape(b, c, t) * x_mask
        return z, logdet

    def skip(self):
        # cache the inverse weight for faster reverse passes
        self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)


class ActNorm(nn.Module):
    def __init__(self, hi_c, ddi=False):   # ddi: data-dependent initialization
        super().__init__()
        self.logs = nn.Parameter(torch.zeros(1, hi_c, 1))
        self.bias = nn.Parameter(torch.zeros(1, hi_c, 1))
        self.ddi = ddi

    def forward(self, x, x_mask=None, reverse=False):
        b, _, t = x.shape
        if x_mask is None:
            x_mask = torch.ones(b, 1, t).to(device=x.device, dtype=x.dtype)
        x_len = torch.sum(x_mask, [1, 2])
        if self.ddi:
            self.initialize(x, x_mask)
            self.ddi = False

        # y = exp(logs) * x + bias : per-channel affine normalization
        if reverse:
            z = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = None
        else:
            z = (torch.exp(self.logs) * x + self.bias) * x_mask
            logdet = torch.sum(self.logs, [1, 2]) * x_len
        return z, logdet

    def initialize(self, x, x_mask):
        with torch.no_grad():
            n = torch.sum(x_mask, [0, 2])
            m = torch.sum(x * x_mask, [0, 2]) / n
            m_s = torch.sum(x * x * x_mask, [0, 2]) / n
            v = m_s - m ** 2
            logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
            init_bias = (-m * torch.exp(-logs)).reshape(*self.bias.shape).to(dtype=self.bias.dtype)   # -m/s
            init_logs = (-logs).reshape(*self.logs.shape).to(dtype=self.logs.dtype)                   # -log(s)
            self.bias.data.copy_(init_bias)
            self.logs.data.copy_(init_logs)

    def set_ddi(self):
        self.ddi = True

    def skip(self):
        pass
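
# Minimal smoke test for the invertible decoder blocks (illustrative only; the
# shapes and layer sizes below are assumptions, not values from the original repo).
# Running the blocks forward and then in reverse should reconstruct the input up
# to numerical error, which is the defining property of the flow.
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(2, 8, 10)            # [b, c, t] with an even channel count
    x_mask = torch.ones(2, 1, 10)

    flow = [ActNorm(hi_c=8), InvConvNear(splits=4), Couplinglayer(in_c=8, hi_c=16, k_s=3)]
    for layer in flow:
        layer.eval()                     # disable dropout so forward and reverse match

    with torch.no_grad():
        z = x
        total_logdet = 0
        for layer in flow:
            z, logdet = layer(z, x_mask)
            total_logdet = total_logdet + logdet
        for layer in reversed(flow):
            z, _ = layer(z, x_mask, reverse=True)

    print("max reconstruction error:", (z - x).abs().max().item())
    print("log-determinant per sample:", total_logdet)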