# Glow-HiFi-TTS / module.py
import torch
import torch.nn as nn
import torch.nn.functional as F
######################################### encoder ##############################################
class Layernorm(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, channels))
        self.beta = nn.Parameter(torch.zeros(1, channels))

    def forward(self, x):
        m = torch.mean(x, dim=1, keepdim=True)
        v = torch.mean((x - m) ** 2, dim=1, keepdim=True)
        x = (x - m) * torch.rsqrt(v + 1e-4)  # normalization over the channel dim
        n = len(x.shape)
        shape = [1, -1] + [1] * (n - 2)
        x = x * self.gamma.reshape(*shape) + self.beta.reshape(*shape)
        return x
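# Illustrative sketch (not part of the original file; sizes are hypothetical):
# unlike nn.LayerNorm, this Layernorm normalizes dim=1 (channels) of a
# [b, c, t] tensor.
def _layernorm_example():
    x = torch.randn(2, 192, 50)   # [batch, channels, frames]
    y = Layernorm(192)(x)         # normalized over the channel dimension
    assert y.shape == (2, 192, 50)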
class Prenet(nn.Module):
    def __init__(self, in_c, hi_c, out_c, k_s=5, layers=3, p=0.05):
        super().__init__()
        self.crn = nn.ModuleList()
        for i in range(layers):
            c_in = in_c if i == 0 else hi_c
            self.crn.extend([nn.Conv1d(c_in, hi_c, k_s, padding=k_s // 2),
                             Layernorm(hi_c), nn.ReLU(), nn.Dropout(p=p)])
        # zero-init so the residual projection starts as an identity mapping
        self.proj = nn.Conv1d(hi_c, out_c, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, start, x_mask=1):
        x = start
        for layer in self.crn:
            x = layer(x)  # [b, c, t]
        x = x * x_mask
        x = self.proj(x) + start  # residual connection; requires out_c == in_c
        end = x * x_mask
        return end  # [b, c, t]
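# Illustrative sketch (hypothetical sizes): the residual `proj(x) + start`
# above requires out_c == in_c.
def _prenet_example():
    net = Prenet(in_c=192, hi_c=192, out_c=192)
    x = torch.randn(2, 192, 40)
    mask = torch.ones(2, 1, 40)
    assert net(x, mask).shape == (2, 192, 40)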
class MultiheadAttention(nn.Module):
    def __init__(self, c, out_c, heads, window_size=4, heads_share=True, p=0.1, block_length=None):
        super().__init__()
        self.k = c // heads
        self.window_size = window_size
        self.proj_q = nn.Conv1d(c, c, 1)
        self.proj_k = nn.Conv1d(c, c, 1)
        self.proj_v = nn.Conv1d(c, c, 1)
        nn.init.xavier_uniform_(self.proj_q.weight)
        nn.init.xavier_uniform_(self.proj_k.weight)
        nn.init.xavier_uniform_(self.proj_v.weight)
        n_heads_rel = 1 if heads_share else heads
        self.d_k = self.k ** (-0.5)
        self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k) * self.d_k)
        self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k) * self.d_k)
        self.conv_o = nn.Conv1d(c, out_c, 1)
        self.drop = nn.Dropout(p=p)

    def forward(self, x, attn_mask=None):
        query, key, value = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        b, c, t = query.shape
        h, k = c // self.k, self.k
        query = query.reshape(b, h, k, t)
        key = key.reshape(b, h, k, t)
        value = value.reshape(b, h, k, t)
        matrix = self.get_relative_matrix(self.emb_rel_k, t)
        rel_logit = torch.matmul(matrix.unsqueeze(0), query)  # [1,1,2t-1,k] x [b,h,k,t] = [b,h,2t-1,t]
        abs_logit = self.rel_to_abs(rel_logit.transpose(2, 3))
        local_score = abs_logit * self.d_k
        score = torch.matmul(query.transpose(2, 3), key) * self.d_k + local_score
        if attn_mask is not None:
            score = score.masked_fill(attn_mask == 0, -1e4)
        align = F.softmax(score, dim=-1)
        atten = self.drop(align)
        self.atten = atten  # kept for attention visualization
        matrix = self.get_relative_matrix(self.emb_rel_v, t).transpose(1, 2)  # [1,k,2t-1]
        weight = self.abs_to_rel(atten).transpose(2, 3)  # [b,h,2t-1,t]
        # atten[i, j] weights key j for query i, so value multiplies atten^T
        output = torch.matmul(value, atten.transpose(2, 3)) + torch.matmul(matrix.unsqueeze(0), weight)  # [b,h,k,t]
        x = self.conv_o(output.contiguous().reshape(b, c, t))
        return x

    def get_relative_matrix(self, emb_rel_k, t):
        s = self.window_size
        pad_size = max(t - s - 1, 0)
        start = max(s + 1 - t, 0)
        emb_rel_k = F.pad(emb_rel_k, (0, 0, pad_size, pad_size))
        return emb_rel_k[:, start:start + 2 * t - 1]  # 2t-1 relative positions

    def rel_to_abs(self, x):
        b, h, t, _ = x.shape
        x = F.pad(x, (0, 1)).reshape(b, h, 2 * t * t)
        x = F.pad(x, (0, t - 1)).reshape(b, h, t + 1, 2 * t - 1)[:, :, :t, t - 1:]
        return x

    def abs_to_rel(self, x):
        b, h, t, _ = x.shape
        x = F.pad(x, (0, t - 1)).reshape(b, h, 2 * t * t - t)
        x = F.pad(x, (t, 0)).reshape(b, h, t, 2 * t)[:, :, :, 1:]
        return x
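# Illustrative sketch (hypothetical sizes): self-attention with windowed
# relative positional embeddings; attn_mask broadcasts over the head dim.
def _attention_example():
    attn = MultiheadAttention(c=192, out_c=192, heads=2)
    x = torch.randn(2, 192, 30)
    mask = torch.ones(2, 1, 30, 30)
    assert attn(x, attn_mask=mask).shape == (2, 192, 30)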
class FFN(nn.Module):
    def __init__(self, h_c, f_c, k_s, p=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(h_c, f_c, k_s, padding=k_s // 2)
        self.conv2 = nn.Conv1d(f_c, h_c, k_s, padding=k_s // 2)
        self.drop = nn.Dropout(p=p)

    def forward(self, x, x_mask=1):
        x = self.conv2(self.drop(F.relu(self.conv1(x * x_mask))) * x_mask)
        return x * x_mask
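# Illustrative sketch (hypothetical sizes): a position-wise convolutional
# feed-forward block, masked before and after each convolution.
def _ffn_example():
    ffn = FFN(h_c=192, f_c=768, k_s=3)
    x = torch.randn(2, 192, 40)
    mask = torch.ones(2, 1, 40)
    assert ffn(x, mask).shape == (2, 192, 40)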
class DurationPredictor(nn.Module):
    def __init__(self, in_c, f_c, k_s, p=0.1):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv1d(in_c, f_c, k_s, padding=k_s // 2),
                                    nn.ReLU(),
                                    Layernorm(f_c),
                                    nn.Dropout(p=p))
        self.block2 = nn.Sequential(nn.Conv1d(f_c, f_c, k_s, padding=k_s // 2),
                                    nn.ReLU(),
                                    Layernorm(f_c),
                                    nn.Dropout(p=p))
        self.proj = nn.Conv1d(f_c, 1, 1)

    def forward(self, x, x_mask):
        x = self.block1(x * x_mask)
        x = self.block2(x * x_mask)
        x = self.proj(x * x_mask)
        return x * x_mask
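# Illustrative sketch (hypothetical sizes): the predictor emits one value per
# token, conventionally interpreted as a log-duration in Glow-TTS.
def _duration_example():
    dp = DurationPredictor(in_c=192, f_c=256, k_s=3)
    x = torch.randn(2, 192, 40)
    mask = torch.ones(2, 1, 40)
    assert dp(x, mask).shape == (2, 1, 40)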
######################################### decoder ##############################################
# TorchScript infers tensor types statically, optimizes the computation
# graph, and compiles ahead of time, which speeds up this hot inner loop.
@torch.jit.script
def fuse_tan_sig_add(x: torch.Tensor, mid: int) -> torch.Tensor:
    a, b = x[:, :mid, :], x[:, mid:, :]
    return torch.sigmoid(a) * torch.tanh(b)
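# Note: this is WaveNet's gated activation, sigmoid(x_a) * tanh(x_b); which
# half of the channels receives sigmoid vs. tanh is immaterial, since the
# convolution producing x can learn either channel ordering.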
class WN(nn.Module):  # non-causal WaveNet without dilation
    def __init__(self, hi_c, k_s, d_l=1, layers=3, p=0.05):
        super().__init__()
        self.hi_c = hi_c
        self.resblocks = nn.ModuleList()
        self.skipblocks = nn.ModuleList()
        self.drop = nn.Dropout(p=p)
        for i in range(layers):
            res_layer = nn.Conv1d(hi_c, 2 * hi_c, k_s, dilation=d_l, padding=k_s // 2)
            res_layer = nn.utils.weight_norm(res_layer, name='weight')
            self.resblocks.append(res_layer)
            if i == layers - 1:
                skip_layer = nn.Conv1d(hi_c, hi_c, 1)  # last layer: skip output only
            else:
                skip_layer = nn.Conv1d(hi_c, 2 * hi_c, 1)  # residual + skip halves
            skip_layer = nn.utils.weight_norm(skip_layer, name='weight')
            self.skipblocks.append(skip_layer)

    def forward(self, x, x_mask=None):
        if x_mask is None:
            x_mask = 1
        mid = self.hi_c
        end = torch.zeros_like(x)
        last = len(self.resblocks) - 1
        for i in range(len(self.resblocks)):
            x = self.drop(self.resblocks[i](x))  # [b, 2c, t]
            x = fuse_tan_sig_add(x, mid)  # [b, c, t]
            y = self.skipblocks[i](x)
            if i == last:
                end = end + y  # last layer contributes skip only
            else:
                x = (x + y[:, :mid, :]) * x_mask  # residual half
                end = end + y[:, mid:, :]  # skip half
        return end * x_mask

    def skip(self):
        # strip weight norm for inference
        for layer1, layer2 in zip(self.resblocks, self.skipblocks):
            nn.utils.remove_weight_norm(layer1)
            nn.utils.remove_weight_norm(layer2)
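# Illustrative sketch (hypothetical sizes): WN keeps the channel count fixed
# and returns the sum of the skip branches.
def _wn_example():
    wn = WN(hi_c=192, k_s=5)
    x = torch.randn(2, 192, 25)
    assert wn(x).shape == (2, 192, 25)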
class Couplinglayer(nn.Module):
    def __init__(self, in_c, hi_c, k_s, d_l=1):
        super().__init__()
        s_proj = nn.Conv1d(in_c // 2, hi_c, 1)
        self.start = nn.utils.weight_norm(s_proj, name='weight')
        # Initializing the last layer to 0 makes the affine coupling layer
        # do nothing at first, which stabilizes training (from the Glow paper).
        self.end = nn.Conv1d(hi_c, in_c, 1)
        self.end.weight.data.zero_()
        self.end.bias.data.zero_()
        self.wn = WN(hi_c, k_s, d_l)

    # forward: y = exp(logs) * x + t; reverse: x = exp(-logs) * (y - t)
    def forward(self, x, x_mask=None, reverse=False):
        if x_mask is None:
            x_mask = 1
        mid = x.shape[1] // 2  # split channels in half
        x_0, x_1 = x[:, :mid, :], x[:, mid:, :]
        z_1 = self.end(self.wn(self.start(x_1) * x_mask, x_mask))
        logs, t = z_1[:, mid:, :], z_1[:, :mid, :]
        if reverse:
            x_0 = torch.exp(-logs) * (x_0 - t) * x_mask
            logdet = None
        else:
            x_0 = (torch.exp(logs) * x_0 + t) * x_mask
            logdet = torch.sum(logs * x_mask, [1, 2])  # sum(log(s))
        z = torch.cat([x_0, x_1], dim=1)
        return z, logdet

    def skip(self):
        self.wn.skip()
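# Illustrative sanity check (not part of the original file): an affine
# coupling layer is invertible, so forward followed by reverse should
# recover the input up to numerical error.
def _coupling_roundtrip_example():
    layer = Couplinglayer(in_c=80, hi_c=192, k_s=5)
    x = torch.randn(2, 80, 25)
    with torch.no_grad():
        z, logdet = layer(x)
        x_rec, _ = layer(z, reverse=True)
    assert torch.allclose(x, x_rec, atol=1e-4)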
class InvConvNear(nn.Module):
    def __init__(self, splits=4):
        super().__init__()
        self.splits = splits
        # orthonormal initialization via QR decomposition
        w_init = torch.linalg.qr(torch.randn(splits, splits))[0]
        if torch.det(w_init) < 0:
            w_init[0, :] = -w_init[0, :]  # keep det(W) = +1
        self.weight = nn.Parameter(w_init)

    def forward(self, x, x_mask=None, reverse=False):
        b, c, t = x.shape
        if x_mask is None:
            x_mask = 1
            x_len = torch.ones(b, dtype=x.dtype, device=x.device) * t  # [b]
        else:
            x_len = torch.sum(x_mask, [1, 2])
        s = self.splits
        x = x.reshape(b, 2, c // s, s // 2, t)  # split channels into 2 groups
        x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(b, s, c // s, t)
        if reverse:
            if hasattr(self, "weight_inv"):
                weight = self.weight_inv
            else:
                weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
            logdet = None
        else:
            weight = self.weight
            logdet = torch.logdet(weight) * (c // s) * x_len  # h*w*log|det(W)|; no LU decomposition needed
        weight = weight.unsqueeze(-1).unsqueeze(-1)
        z = F.conv2d(x, weight)  # z = matmul(weight, x[:, :, i, j]) over h = c//s, w = t
        z = z.reshape(b, 2, s // 2, c // s, t).permute(0, 1, 3, 2, 4).contiguous().reshape(b, c, t) * x_mask
        return z, logdet

    def skip(self):
        # cache the inverse weight for fast inference
        self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
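# Illustrative sanity check (hypothetical sizes): the channel count must be
# divisible by `splits`; the invertible 1x1 convolution round-trips exactly.
def _invconv_roundtrip_example():
    conv = InvConvNear(splits=4)
    x = torch.randn(2, 8, 25)
    with torch.no_grad():
        z, logdet = conv(x)
        x_rec, _ = conv(z, reverse=True)
    assert torch.allclose(x, x_rec, atol=1e-4)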
class ActNorm(nn.Module):
    def __init__(self, hi_c, ddi=False):  # ddi: data-dependent initialization
        super().__init__()
        self.logs = nn.Parameter(torch.zeros(1, hi_c, 1))
        self.bias = nn.Parameter(torch.zeros(1, hi_c, 1))
        self.ddi = ddi

    def forward(self, x, x_mask=None, reverse=False):
        b, _, t = x.shape
        if x_mask is None:
            x_mask = torch.ones(b, 1, t).to(device=x.device, dtype=x.dtype)
        x_len = torch.sum(x_mask, [1, 2])
        if self.ddi:
            self.initialize(x, x_mask)
            self.ddi = False
        # y = exp(logs) * x + bias -> per-channel normalization
        if reverse:
            z = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = None
        else:
            z = (torch.exp(self.logs) * x + self.bias) * x_mask
            logdet = torch.sum(self.logs, [1, 2]) * x_len
        return z, logdet

    def initialize(self, x, x_mask):
        with torch.no_grad():
            n = torch.sum(x_mask, [0, 2])
            m = torch.sum(x * x_mask, [0, 2]) / n
            m_s = torch.sum(x * x * x_mask, [0, 2]) / n
            v = m_s - m ** 2
            logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
            init_bias = (-m * torch.exp(-logs)).reshape(*self.bias.shape).to(dtype=self.bias.dtype)  # -m/s
            init_logs = (-logs).reshape(*self.logs.shape).to(dtype=self.logs.dtype)  # -log(s)
            self.bias.data.copy_(init_bias)
            self.logs.data.copy_(init_logs)

    def set_ddi(self):
        self.ddi = True

    def skip(self):
        pass
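# Illustrative sketch (hypothetical sizes): with DDI enabled, the first batch
# sets bias/logs so the output is roughly zero-mean, unit-variance per
# channel, as in the Glow paper.
def _actnorm_ddi_example():
    norm = ActNorm(hi_c=80)
    norm.set_ddi()                 # first forward pass initializes bias/logs
    x = torch.randn(2, 80, 25) * 3.0 + 1.0
    with torch.no_grad():
        z, logdet = norm(x)
    assert z.mean().abs() < 0.1
    assert (z.std() - 1.0).abs() < 0.1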