import torch
import torch.nn as nn
import torch.nn.functional as F

######################################### encoder ##############################################
class Layernorm(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, channels))
        self.beta = nn.Parameter(torch.zeros(1, channels))

    def forward(self, x):
        # normalize over the channel dimension of [b, c, t]
        m = torch.mean(x, dim=1, keepdim=True)
        v = torch.mean((x - m) ** 2, dim=1, keepdim=True)
        x = (x - m) * torch.rsqrt(v + 1e-4)  # normalization
        n = len(x.shape)
        shape = [1, -1] + [1] * (n - 2)
        x = x * self.gamma.reshape(*shape) + self.beta.reshape(*shape)
        return x
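
# Illustrative sanity check (not part of the original file): with the default gamma=1 and
# beta=0, every [batch, time] position should come out roughly zero-mean and unit-variance
# across channels. The sizes (2, 8, 50) are arbitrary assumptions.
def _check_layernorm():
    ln = Layernorm(channels=8)
    x = torch.randn(2, 8, 50)
    y = ln(x)
    print(y.mean(dim=1).abs().max())  # close to 0
    print(y.var(dim=1, unbiased=False).mean())  # close to 1 (up to the 1e-4 epsilon)
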
class Prenet(nn.Module):
    def __init__(self, in_c, hi_c, out_c, k_s=5, layers=3, p=0.05):
        super().__init__()
        self.crn = nn.ModuleList()
        # conv -> layernorm -> relu -> dropout, repeated `layers` times
        self.crn.extend([nn.Conv1d(in_c, hi_c, k_s, padding=k_s // 2), Layernorm(hi_c), nn.ReLU(), nn.Dropout(p=p)])
        for _ in range(layers - 1):
            self.crn.extend([nn.Conv1d(hi_c, hi_c, k_s, padding=k_s // 2), Layernorm(hi_c), nn.ReLU(), nn.Dropout(p=p)])
        # zero-initialized projection so the prenet starts out as a pure residual (identity) mapping
        self.proj = nn.Conv1d(hi_c, out_c, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, start, x_mask=1):
        x = start
        for layer in self.crn:
            x = layer(x)  # [b, c, t]
        x = x * x_mask
        x = self.proj(x) + start  # residual connection; requires out_c == in_c
        end = x * x_mask
        return end  # [b, c, t]
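
# Usage sketch (added for illustration): the residual connection `proj(x) + start` assumes
# in_c == out_c, so a typical call uses matching channel counts. All sizes are assumptions.
def _demo_prenet():
    prenet = Prenet(in_c=192, hi_c=192, out_c=192)
    x = torch.randn(2, 192, 50)
    x_mask = torch.ones(2, 1, 50)
    y = prenet(x, x_mask)
    print(y.shape)  # torch.Size([2, 192, 50]); equals x at init because proj is zero-initialized
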
class MultiheadAttention(nn.Module):
    def __init__(self, c, out_c, heads, window_size=4, heads_share=True, p=0.1, block_length=None):
        super().__init__()
        self.k = c // heads
        self.window_size = window_size
        self.proj_q = nn.Conv1d(c, c, 1)
        self.proj_k = nn.Conv1d(c, c, 1)
        self.proj_v = nn.Conv1d(c, c, 1)
        nn.init.xavier_uniform_(self.proj_q.weight)
        nn.init.xavier_uniform_(self.proj_k.weight)
        nn.init.xavier_uniform_(self.proj_v.weight)
        n_heads_rel = 1 if heads_share else heads
        self.d_k = self.k ** (-0.5)
        self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k) * self.d_k)
        self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k) * self.d_k)
        self.conv_o = nn.Conv1d(c, out_c, 1)
        self.drop = nn.Dropout(p=p)
    def forward(self, x, attn_mask=None):
        query, key, value = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        b, c, t = query.shape
        h, k = c // self.k, self.k
        query = query.reshape(b, h, k, t)
        key = key.reshape(b, h, k, t)
        value = value.reshape(b, h, k, t)
        # relative-position contribution to the logits
        matrix = self.get_relative_matrix(self.emb_rel_k, t)
        rel_logit = torch.matmul(matrix.unsqueeze(0), query)  # [1,1,2t-1,k] x [b,h,k,t] = [b,h,2t-1,t]
        abs_logit = self.rel_to_abs(rel_logit.transpose(2, 3))
        local_score = abs_logit * self.d_k
        score = torch.matmul(query.transpose(2, 3), key) * self.d_k + local_score
        if attn_mask is not None:
            score = score.masked_fill(attn_mask == 0, -1e4)
        align = F.softmax(score, dim=-1)  # align[b,h,i,j]: weight of key j for query i
        atten = self.drop(align)
        self.atten = atten
        # relative-position contribution to the values
        matrix = self.get_relative_matrix(self.emb_rel_v, t).transpose(1, 2)  # [1,k,2t-1]
        weight = self.abs_to_rel(atten).transpose(2, 3)  # [b,h,2t-1,t]
        output = torch.matmul(value, atten.transpose(2, 3)) + torch.matmul(matrix.unsqueeze(0), weight)  # [b,h,k,t]
        x = self.conv_o(output.contiguous().reshape(b, c, t))
        return x
    def get_relative_matrix(self, emb_rel_k, t):
        # pad (if needed) and slice the [n_heads, 2*window+1, k] table to 2*t-1 relative positions
        s = self.window_size
        pad_size = max(t - s - 1, 0)
        start = max(s + 1 - t, 0)
        emb_rel_k = F.pad(emb_rel_k, (0, 0, pad_size, pad_size))
        return emb_rel_k[:, start:start + 2 * t - 1]
    def rel_to_abs(self, x):
        # [b,h,t,2t-1] relative-indexed logits -> [b,h,t,t] absolute-indexed logits
        b, h, t, _ = x.shape
        x = F.pad(x, (0, 1)).reshape(b, h, 2 * t * t)
        x = F.pad(x, (0, t - 1)).reshape(b, h, t + 1, 2 * t - 1)[:, :, :t, t - 1:]
        return x

    def abs_to_rel(self, x):
        # [b,h,t,t] absolute-indexed weights -> [b,h,t,2t-1] relative-indexed weights
        b, h, t, _ = x.shape
        x = F.pad(x, (0, t - 1)).reshape(b, h, 2 * t * t - t)
        x = F.pad(x, (t, 0)).reshape(b, h, t, 2 * t)[:, :, :, 1:]
        return x
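
# Illustrative shape check (not part of the original file): self-attention with relative
# positions keeps the [b, c, t] layout. Channel count, head count and lengths are assumptions.
def _check_attention_shapes():
    mha = MultiheadAttention(c=192, out_c=192, heads=2, window_size=4)
    x = torch.randn(2, 192, 50)
    attn_mask = torch.ones(2, 1, 50, 50)
    y = mha(x, attn_mask)
    print(y.shape)  # torch.Size([2, 192, 50])
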
class FFN(nn.Module):
    def __init__(self, h_c, f_c, k_s, p=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(h_c, f_c, k_s, padding=k_s // 2)
        self.conv2 = nn.Conv1d(f_c, h_c, k_s, padding=k_s // 2)
        self.drop = nn.Dropout(p=p)

    def forward(self, x, x_mask=None):
        if x_mask is None:
            x_mask = 1
        x = self.conv2(self.drop(F.relu(self.conv1(x * x_mask))) * x_mask)
        return x * x_mask
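
# Usage sketch (illustrative): the FFN is position-wise, so masked frames stay zero after the
# final multiply. The filter size f_c=768 and the other sizes are assumptions.
def _demo_ffn():
    ffn = FFN(h_c=192, f_c=768, k_s=3)
    x = torch.randn(2, 192, 50)
    x_mask = torch.ones(2, 1, 50)
    x_mask[:, :, 40:] = 0  # pretend the last 10 frames are padding
    y = ffn(x, x_mask)
    print(y.shape, y[:, :, 40:].abs().max())  # torch.Size([2, 192, 50]), tensor(0.)
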
class DurationPredictor(nn.Module):
    def __init__(self, in_c, f_c, k_s, p=0.1):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv1d(in_c, f_c, k_s, padding=k_s // 2),
                                    nn.ReLU(),
                                    Layernorm(f_c),
                                    nn.Dropout(p=p))
        self.block2 = nn.Sequential(nn.Conv1d(f_c, f_c, k_s, padding=k_s // 2),
                                    nn.ReLU(),
                                    Layernorm(f_c),
                                    nn.Dropout(p=p))
        self.proj = nn.Conv1d(f_c, 1, 1)

    def forward(self, x, x_mask):
        x = self.block1(x * x_mask)
        x = self.block2(x * x_mask)
        x = self.proj(x * x_mask)
        return x * x_mask
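
# Usage sketch (illustrative): the duration predictor maps encoder features [b, c, t] to one
# value per token, typically read as a log-duration. Sizes below are assumptions.
def _demo_duration_predictor():
    dp = DurationPredictor(in_c=192, f_c=256, k_s=3)
    x = torch.randn(2, 192, 50)
    x_mask = torch.ones(2, 1, 50)
    logw = dp(x, x_mask)
    print(logw.shape)  # torch.Size([2, 1, 50])
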
######################################### decoder ##############################################
# TorchScript-style fusion: static typing (so tensor types can be inferred), computation-graph
# optimization, and ahead-of-time compilation of this elementwise op to speed it up.
@torch.jit.script
def fuse_tan_sig_add(x: torch.Tensor, mid: int) -> torch.Tensor:
    # gated activation: sigmoid gate on the first half of the channels, tanh on the second half
    a, b = x[:, :mid, :], x[:, mid:, :]
    return torch.sigmoid(a) * torch.tanh(b)
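
# Worked example (illustrative): with mid=2, an input of shape [b, 4, t] is split into two
# [b, 2, t] halves and combined as sigmoid(first) * tanh(second), keeping the [b, 2, t] shape.
def _demo_fuse_tan_sig_add():
    x = torch.randn(1, 4, 3)
    out = fuse_tan_sig_add(x, 2)
    expected = torch.sigmoid(x[:, :2, :]) * torch.tanh(x[:, 2:, :])
    print(out.shape, torch.allclose(out, expected))  # torch.Size([1, 2, 3]) True
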
class WN(nn.Module):  # non-causal WaveNet-style stack without dilation growth
    def __init__(self, hi_c, k_s, d_l=1, layers=3, p=0.05):
        super().__init__()
        self.hi_c = hi_c
        self.resblocks = nn.ModuleList()
        self.skipblocks = nn.ModuleList()
        self.drop = nn.Dropout(p=p)
        for i in range(layers):
            res_layer = nn.Conv1d(hi_c, 2 * hi_c, k_s, dilation=d_l, padding=k_s // 2)
            res_layer = nn.utils.weight_norm(res_layer, name='weight')
            self.resblocks.append(res_layer)
            if i == layers - 1:
                skip_layer = nn.Conv1d(hi_c, hi_c, 1)  # last layer: skip path only
            else:
                skip_layer = nn.Conv1d(hi_c, 2 * hi_c, 1)  # residual + skip paths
            skip_layer = nn.utils.weight_norm(skip_layer, name='weight')
            self.skipblocks.append(skip_layer)

    def forward(self, x, x_mask=None):
        if x_mask is None:
            x_mask = 1
        mid = self.hi_c
        end = torch.zeros_like(x)
        last = len(self.resblocks) - 1
        for i in range(len(self.resblocks)):
            x = self.drop(self.resblocks[i](x))  # [b, 2c, t]
            x = fuse_tan_sig_add(x, mid)  # [b, c, t]
            y = self.skipblocks[i](x)
            if i == last:
                end = end + y  # last layer contributes to the skip sum only
            else:
                x = (x + y[:, :mid, :]) * x_mask  # residual path
                end = end + y[:, mid:, :]  # skip path
        return end * x_mask

    def skip(self):
        for layer1, layer2 in zip(self.resblocks, self.skipblocks):
            nn.utils.remove_weight_norm(layer1)
            nn.utils.remove_weight_norm(layer2)
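
# Usage sketch (illustrative): WN keeps the channel count, so it can sit inside a coupling
# layer. Hidden size and lengths are assumptions.
def _demo_wn():
    wn = WN(hi_c=192, k_s=5)
    x = torch.randn(2, 192, 50)
    y = wn(x)  # x_mask defaults to 1 (no padding)
    print(y.shape)  # torch.Size([2, 192, 50])
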
class Couplinglayer(nn.Module):
    def __init__(self, in_c, hi_c, k_s, d_l=1):
        super().__init__()
        s_proj = nn.Conv1d(in_c // 2, hi_c, 1)
        self.start = nn.utils.weight_norm(s_proj, name='weight')
        # Initializing the last layer to zero makes the affine coupling layer do nothing at
        # first, which helps stabilize training (from the Glow paper).
        self.end = nn.Conv1d(hi_c, in_c, 1)
        self.end.weight.data.zero_()
        self.end.bias.data.zero_()
        self.wn = WN(hi_c, k_s, d_l)

    # forward: y_0 = x_0 * exp(logs) + t, with (logs, t) computed from the untouched half x_1
    def forward(self, x, x_mask=None, reverse=False):
        if x_mask is None:
            x_mask = 1
        mid = x.shape[1] // 2  # split the channels in half
        x_0, x_1 = x[:, :mid, :], x[:, mid:, :]
        z_1 = self.end(self.wn(self.start(x_1) * x_mask, x_mask))
        logs, t = z_1[:, mid:, :], z_1[:, :mid, :]
        if reverse:
            x_0 = torch.exp(-logs) * (x_0 - t) * x_mask
            logdet = None
        else:
            x_0 = (torch.exp(logs) * x_0 + t) * x_mask
            logdet = torch.sum(logs * x_mask, [1, 2])  # sum(log(s)) over masked positions
        z = torch.cat([x_0, x_1], dim=1)
        return z, logdet

    def skip(self):
        self.wn.skip()
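
# Illustrative round-trip check (not part of the original file): an affine coupling layer must
# be exactly invertible; the zero-initialized `end` conv makes it start as the identity, so the
# weights are perturbed here to make the check non-trivial. All sizes below are assumptions.
def _check_coupling_invertible():
    layer = Couplinglayer(in_c=4, hi_c=8, k_s=3).eval()  # eval(): dropout off, so both passes match
    with torch.no_grad():
        layer.end.weight.normal_(0.0, 0.1)  # perturb the zero init
        y, logdet = layer(x := torch.randn(2, 4, 10))
        x_rec, _ = layer(y, reverse=True)
    print(torch.allclose(x, x_rec, atol=1e-5), logdet.shape)  # True torch.Size([2])
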
class InvConvNear(nn.Module):
    def __init__(self, splits=4):
        super().__init__()
        self.splits = splits
        w_init = torch.linalg.qr(torch.randn(splits, splits))[0]  # random orthonormal matrix
        if torch.det(w_init) < 0:
            w_init[0, :] = -w_init[0, :]  # force det(W) = +1
        self.weight = nn.Parameter(w_init)

    def forward(self, x, x_mask=None, reverse=False):
        b, c, t = x.shape
        if x_mask is None:
            x_mask = 1
            x_len = torch.ones(b, dtype=x.dtype, device=x.device) * t  # [b]
        else:
            x_len = torch.sum(x_mask, [1, 2])
        s = self.splits
        x = x.reshape(b, 2, c // s, s // 2, t)  # split the channels into 2 groups
        x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(b, s, c // s, t)
        if reverse:
            if hasattr(self, "weight_inv"):
                weight = self.weight_inv
            else:
                weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
            logdet = None
        else:
            weight = self.weight
            logdet = torch.logdet(weight) * (c // s) * x_len  # h*w*log(det(W)); no decomposition needed
        weight = weight.unsqueeze(-1).unsqueeze(-1)
        z = F.conv2d(x, weight)  # z = matmul(weight, x[:, :, i, j]) for i, j in h = c//s, w = t
        z = z.reshape(b, 2, s // 2, c // s, t).permute(0, 1, 3, 2, 4).contiguous().reshape(b, c, t) * x_mask
        return z, logdet

    def skip(self):
        # cache the inverse weight for fast reverse passes at inference time
        self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
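
# Illustrative round-trip check (not part of the original file): the invertible 1x1 convolution
# over grouped channels should reconstruct its input, and with det(W) forced to +1 the logdet
# term starts near 0. Sizes are assumptions; the channel count must be divisible by `splits`.
def _check_invconv_invertible():
    conv = InvConvNear(splits=4)
    x = torch.randn(2, 8, 10)
    with torch.no_grad():
        y, logdet = conv(x)
        x_rec, _ = conv(y, reverse=True)
    print(torch.allclose(x, x_rec, atol=1e-4), logdet)  # True, values near 0
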
class ActNorm(nn.Module):
    def __init__(self, hi_c, ddi=False):  # ddi: data-dependent initialization
        super().__init__()
        self.logs = nn.Parameter(torch.zeros(1, hi_c, 1))
        self.bias = nn.Parameter(torch.zeros(1, hi_c, 1))
        self.ddi = ddi

    def forward(self, x, x_mask=None, reverse=False):
        b, _, t = x.shape
        if x_mask is None:
            x_mask = torch.ones(b, 1, t).to(device=x.device, dtype=x.dtype)
        x_len = torch.sum(x_mask, [1, 2])
        if self.ddi:
            self.initialize(x, x_mask)
            self.ddi = False
        # y = exp(logs) * x + bias  ->  per-channel normalization
        if reverse:
            z = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = None
        else:
            z = (torch.exp(self.logs) * x + self.bias) * x_mask
            logdet = torch.sum(self.logs, [1, 2]) * x_len
        return z, logdet

    def initialize(self, x, x_mask):
        with torch.no_grad():
            n = torch.sum(x_mask, [0, 2])
            m = torch.sum(x * x_mask, [0, 2]) / n
            m_s = torch.sum(x * x * x_mask, [0, 2]) / n
            v = m_s - m ** 2
            logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))  # log of the per-channel std
            init_bias = (-m * torch.exp(-logs)).reshape(*self.bias.shape).to(dtype=self.bias.dtype)  # -m/s
            init_logs = (-logs).reshape(*self.logs.shape).to(dtype=self.logs.dtype)  # -log(s)
            self.bias.data.copy_(init_bias)
            self.logs.data.copy_(init_logs)

    def set_ddi(self):
        self.ddi = True

    def skip(self):
        pass
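
# Illustrative check of data-dependent initialization (not part of the original file): after the
# first forward pass with ddi=True, each channel of the output should be roughly zero-mean and
# unit-variance over the masked positions. Sizes are assumptions.
def _check_actnorm_ddi():
    actnorm = ActNorm(hi_c=4, ddi=True)
    x = torch.randn(2, 4, 100) * 3.0 + 1.5  # deliberately un-normalized input
    with torch.no_grad():
        z, logdet = actnorm(x)
    print(z.mean(dim=(0, 2)).abs().max(), z.var(dim=(0, 2), unbiased=False).mean(), logdet.shape)
    # means near 0, variance near 1, torch.Size([2])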