Spaces:
Runtime error
Runtime error
from torch import nn | |
import numpy as np | |
import torch.nn.functional as F | |
from torch.nn.utils import weight_norm | |
import math | |
import torch | |
# Default compute device: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Size of the text-symbol vocabulary (number of rows in the Encoder's embedding table).
symbol_length = 73
class GlowTTS(nn.Module):
    """
    Glow-TTS: a text Encoder (prior statistics + duration predictor) and a flow-based Decoder.

    Training aligns text tokens with mel frames through monotonic alignment search
    (maximum_path); inference expands the predicted durations with generate_path.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    @staticmethod
    def _expand_by_alignment(attention_alignment, x_mean, x_log_std, x_mask):
        """
        Expand per-token prior statistics to per-frame statistics along a hard alignment.

        attention_alignment: (B, T, F), x_mean / x_log_std: (B, 80, T), x_mask: (B, 1, T)
        Returns z_mean (B, 80, F), z_log_std (B, 80, F) and log_d (B, 1, T), the log of the
        number of frames each token receives (1e-8 avoids log(0)).
        """
        align_t = attention_alignment.transpose(1, 2)  # (B, F, T)
        z_mean = torch.matmul(align_t, x_mean.transpose(1, 2)).transpose(1, 2)  # (B, F, 80) -> (B, 80, F)
        z_log_std = torch.matmul(align_t, x_log_std.transpose(1, 2)).transpose(1, 2)  # (B, 80, F)
        log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask  # (B, 1, T)
        return z_mean, z_log_std, log_d

    def forward(self, text, text_len, mel=None, mel_len=None, inference=False,
                noise_scale=1.0, length_scale=1.0):
        """
        =====inputs=====
        text: (B, T)
        text_len: (B) list or tensor
        mel: (B, 80, F) | used in training
        mel_len: (B) list or tensor | used in training
        inference: True/False
        noise_scale: multiplier of the sampled latent noise at inference (1.0 = unscaled)
        length_scale: multiplier of the predicted durations at inference (1.0 = unscaled,
                      values > 1 give every token more frames)
        =====outputs=====
        (tuple) (z, z_mean, z_log_std, log_det, z_mask)
            z(training) or y(inference): (B, 80, F) | z: latent representation, y: mel-spectrogram
            z_mean: (B, 80, F)
            z_log_std: (B, 80, F)
            log_det: (B) or None
            z_mask: (B, 1, F)
        (tuple) (x_mean, x_log_std, x_mask)
            x_mean: (B, 80, T)
            x_log_std: (B, 80, T)
            x_mask: (B, 1, T)
        (tuple) (attention_alignment, x_log_dur, log_d)
            attention_alignment: (B, T, F)
            x_log_dur: (B, 1, T) | predicted duration, log scale
            log_d: (B, 1, T) | duration implied by the chosen alignment, log scale
        """
        x_mean, x_log_std, x_log_dur, x_mask = self.encoder(text, text_len)
        # std and duration are carried in log scale, as in the authors' implementation.
        mask_device = x_mask.device
        y = mel
        if not inference:  # training
            y_max_len = y.size(2)
            # mel_len may be a python list; make it a tensor on the model's device so the
            # integer arithmetic and comparisons below work.
            y_len = torch.as_tensor(mel_len, device=mask_device)
        else:  # inference
            dur = torch.exp(x_log_dur) * x_mask * length_scale  # (B, 1, T)
            ceil_dur = torch.ceil(dur)  # (B, 1, T)
            # Total frames per item (sum over axes 1 and 2), at least 1, as a long tensor.
            y_len = torch.clamp_min(torch.sum(ceil_dur, [1, 2]), 1).long()  # (B)
            y_max_len = None

        # preprocessing: the decoder squeezes pairs of frames, so lengths must be even.
        if y_max_len is not None:
            y_max_len = (y_max_len // 2) * 2  # drop one frame if odd
            y = y[:, :, :y_max_len]
        y_len = (y_len // 2) * 2  # drop one frame if odd

        # z_mask (B, 1, F). In training its length is the (trimmed) mel length so it always
        # matches y, even if mel is padded beyond max(mel_len); at inference it is max(y_len).
        mask_len = y_max_len if y_max_len is not None else int(y_len.max())
        frame_idx = torch.arange(mask_len, device=mask_device)
        z_mask = (frame_idx.unsqueeze(0) < y_len.unsqueeze(1)).unsqueeze(1)  # (B, 1, F) bool

        # (B, 1, T, 1) * (B, 1, 1, F) = (B, 1, T, F). Not the Encoder's self-attention mask.
        attention_mask = x_mask.unsqueeze(3) * z_mask.unsqueeze(2)

        if not inference:  # training
            z, log_det = self.decoder(y, z_mask, reverse=False)
            with torch.no_grad():
                # logp[b, t, f] = sum over the 80 channels of log N(z[:, f]; x_mean[:, t], exp(x_log_std[:, t])),
                # expanded with the distributive law into four matmul-friendly terms.
                x_inv_var = torch.exp(-2 * x_log_std)  # 1 / std^2, (B, 80, T)
                logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_log_std, [1]).unsqueeze(-1)  # (B, T, 1)
                logp2 = torch.matmul(x_inv_var.transpose(1, 2), -0.5 * (z ** 2))  # (B, T, 80) x (B, 80, F) = (B, T, F)
                logp3 = torch.matmul((x_mean * x_inv_var).transpose(1, 2), z)  # (B, T, 80) x (B, 80, F) = (B, T, F)
                logp4 = torch.sum(-0.5 * (x_mean ** 2) * x_inv_var, [1]).unsqueeze(-1)  # (B, T, 1)
                logp = logp1 + logp2 + logp3 + logp4  # (B, T, F)
                attention_alignment = maximum_path(logp, attention_mask.squeeze(1)).detach()  # (B, T, F)
            z_mean, z_log_std, log_d = self._expand_by_alignment(attention_alignment, x_mean, x_log_std, x_mask)
            return (z, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)

        # inference: build the alignment from the ceiled predicted durations.
        attention_alignment = generate_path(ceil_dur.squeeze(1), attention_mask.squeeze(1))  # (B, T, F)
        z_mean, z_log_std, log_d = self._expand_by_alignment(attention_alignment, x_mean, x_log_std, x_mask)
        # Sample the latent z and run the flow in reverse to obtain the mel-spectrogram.
        z = (z_mean + torch.exp(z_log_std) * torch.randn_like(z_mean) * noise_scale) * z_mask
        y, log_det = self.decoder(z, z_mask, reverse=True)
        return (y, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
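##### NOTE: the original paper's implementations below (maximum_path, generate_path) are much faster; the model's implementation should be switched to them. Both names are defined again further down in this file, and at import time those later definitions replace these ones. #####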
def maximum_path(value, mask, max_neg_val=-np.inf):
    """ Numpy-friendly version. It's about 4 times faster than torch version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]

    Monotonic alignment search vectorised over the batch: returns a tensor of shape
    [b, t_x, t_y] (on value's device, with value's dtype) holding 1 for every
    (text, frame) pair on the highest-scoring monotonic path and 0 elsewhere.

    NOTE: a function with the same name is defined again later in this file; that later
    definition replaces this one at import time.
    """
    value = value * mask  # zero the scores outside the valid region
    device = value.device
    dtype = value.dtype
    value = value.cpu().detach().numpy()
    mask = mask.cpu().detach().numpy().astype(bool)

    b, t_x, t_y = value.shape
    # direction[:, i, j] = 1 if the best path into (i, j) came from (i, j-1) (stay on token i),
    # 0 if it came from (i-1, j-1) (advance from the previous token).
    direction = np.zeros(value.shape, dtype=np.int64)
    # v[:, i]: best cumulative score of a path ending at text index i on the current frame.
    v = np.zeros((b, t_x), dtype=np.float32)
    x_range = np.arange(t_x, dtype=np.float32).reshape(1,-1)
    for j in range(t_y):
        # v0[:, i] = v[:, i-1] (advance); max_neg_val for i = 0, which has no previous token.
        v0 = np.pad(v, [[0,0],[1,0]], mode="constant", constant_values=max_neg_val)[:, :-1]
        v1 = v  # stay on the same token
        max_mask = (v1 >= v0)  # ties prefer staying
        v_max = np.where(max_mask, v1, v0)
        direction[:, :, j] = max_mask

        index_mask = (x_range <= j)  # token i cannot be reached before frame i
        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
    # Outside the mask force "stay" so backtracking does not move through padded cells.
    direction = np.where(mask, direction, 1)

    path = np.zeros(value.shape, dtype=np.float32)
    # Backtracking starts from the last valid text index of each item.
    index = mask[:, :, 0].sum(1).astype(np.int64) - 1
    index_range = np.arange(b)
    for j in reversed(range(t_y)):
        path[index_range, index, j] = 1
        index = index + direction[index_range, index, j] - 1  # -1 when the path advanced here, 0 when it stayed
    path = path * mask.astype(np.float32)
    path = torch.from_numpy(path).to(device=device, dtype=dtype)
    return path
def generate_path(duration, mask):
    """
    Paper version: turn per-token durations into a hard monotonic alignment.

    duration: [b, t_x]
    mask: [b, t_x, t_y]
    returns: [b, t_x, t_y] float32 path with 1 where text position t_x is assigned frame t_y.

    NOTE: a function with the same name is defined again later in this file and replaces
    this one at import time.
    """
    b, t_x, t_y = mask.shape  # (B, T, F)
    cum_duration = torch.cumsum(duration, 1)  # (B, T) cumulative end frame of each token
    # Row (b, t) is True for every frame f < cum_duration[b, t].
    path = sequence_mask(cum_duration.reshape(b * t_x), t_y)  # (B*T, F)
    path = path.view(b, t_x, t_y).to(torch.float32)  # (B, T, F)
    # Subtract the previous token's row (shift by one along T) so each row keeps only its own frames.
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    return path * mask
def sequence_mask(length, max_length=None):
    """
    Boolean mask of shape (len(length), max_length): entry [i, p] is True when p < length[i].
    max_length defaults to the largest value in `length`.
    """
    n_positions = length.max() if max_length is None else max_length
    positions = torch.arange(n_positions, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
def convert_pad_shape(pad_shape):
    """
    Flatten per-dimension [before, after] pairs, listed first dimension first, into the
    flat list F.pad expects (last dimension first).
    e.g. [[0, 0], [1, 0], [0, 0]] -> [0, 0, 1, 0, 0, 0]
    """
    return [amount for dim_pad in reversed(pad_shape) for amount in dim_pad]
def MAS(path, logp, T_max, F_max):
    """
    Monotonic Alignment Search for one utterance (helper of maximum_path).

    Forward pass: overwrite `logp` in place with cumulative scores, where each cell holds
    the best total log-likelihood of a monotonic path from (0, 0) to that cell.
    Backward pass: trace the best path from (T_max-1, F_max-1) back to column 0, marking
    it in `path`.

    =====inputs=====
    path: (T, F) | integer array, written with 1 along the chosen path
    logp: (T, F) | log-likelihoods; modified in place
    T_max: number of valid text positions
    F_max: number of valid frames
    =====outputs=====
    path: (T, F) | the same array object, an alignment made of 0s and 1s
    """
    very_negative = -1e9  # stands in for -inf on cells that cannot be used

    # Forward pass, column by column. Only rows that are reachable from (0, 0) and can
    # still reach (T_max-1, F_max-1) monotonically are visited (a parallelogram).
    for col in range(F_max):
        row_lo = max(0, T_max + col - F_max)
        row_hi = min(T_max, col + 1)
        for row in range(row_lo, row_hi):
            # Staying on the same token: (row, col-1) is invalid on the diagonal.
            stay = very_negative if row == col else logp[row, col - 1]
            # Advancing from the previous token: row 0 has none, except at the start cell.
            if row > 0:
                advance = logp[row - 1, col - 1]
            elif col == 0:
                advance = 0.
            else:
                advance = very_negative
            logp[row, col] = max(stay, advance) + logp[row, col]

    # Backward pass from the last valid text position.
    row = T_max - 1
    for col in reversed(range(F_max)):
        path[row, col] = 1
        if row == 0:
            continue
        if row == col or logp[row, col - 1] < logp[row - 1, col - 1]:
            row -= 1
    return path
def maximum_path(logp, attention_mask):
    """
    Find the text/frame alignment of every batch item with MAS (a pure-Python loop on the CPU).
    The authors' implementation appears to use Cython for this; here it is plain Python.

    =====inputs=====
    logp: (B, T, F) | log-likelihood under N(x_mean, x_std)
    attention_mask: (B, T, F)
    =====outputs=====
    path: (B, T, F) | alignment, on logp's device, with the dtype of logp * attention_mask
    """
    scores = logp * attention_mask
    out_device, out_dtype = scores.device, scores.dtype
    # MAS works on NumPy arrays on the CPU; the result is moved back to out_device at the end.
    scores_np = scores.data.cpu().numpy().astype(np.float32)
    mask_np = attention_mask.data.cpu().numpy()
    paths = np.zeros_like(scores_np).astype(np.int32)  # (B, T, F)
    text_lens = mask_np.sum(1)[:, 0].astype(np.int32)   # valid T per item (from frame column 0)
    frame_lens = mask_np.sum(2)[:, 0].astype(np.int32)  # valid F per item (from text row 0)
    for b in range(scores_np.shape[0]):
        paths[b] = MAS(paths[b], scores_np[b], text_lens[b], frame_lens[b])  # (T, F)
    return torch.from_numpy(paths).to(device=out_device, dtype=out_dtype)
def generate_path(ceil_dur, attention_mask):
    """
    Build the inference-time alignment from ceiled predicted durations.

    Token t is assigned the frames [cum_dur[t-1], cum_dur[t]). Frames past the end of the
    mask's frame axis are dropped. This happens when the caller rounds the output length
    down to an even number, which previously raised a shape-mismatch error.

    =====input=====
    ceil_dur: (B, T) | ceiled predicted durations, e.g. [[2, 1, 2, 2, ...], [1, 2, 1, 3, ...], ...]
    attention_mask: (B, T, F)
    =====output=====
    path: (B, T, F) | float32 alignment
    """
    n_frames = attention_mask.shape[2]
    # Cumulative end frame of each token, e.g. [[2, 3, 5, 7, ...], [1, 3, 4, 7, ...], ...]
    cum_dur = torch.cumsum(ceil_dur, 1).to(torch.int32)  # (B, T)
    frame_idx = torch.arange(n_frames, device=ceil_dur.device, dtype=torch.int32)
    # path[b, t, f] = 1 iff f < cum_dur[b, t]; automatically clipped to n_frames.
    path = (frame_idx.view(1, 1, n_frames) < cum_dur.unsqueeze(-1)).to(torch.float32)  # (B, T, F)
    # Subtract the previous token's row so each row keeps only its own frames, e.g.
    #   [[1,1,0,0,0,0,0],     [[0,0,0,0,0,0,0],     [[1,1,0,0,0,0,0],
    #    [1,1,1,0,0,0,0],  -   [1,1,0,0,0,0,0],  =   [0,0,1,0,0,0,0],
    #    [1,1,1,1,1,0,0],      [1,1,1,0,0,0,0],      [0,0,0,1,1,0,0],
    #    [1,1,1,1,1,1,1]]      [1,1,1,1,1,0,0]]      [0,0,0,0,0,1,1]]
    path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]  # (B, T, F)
    return path * attention_mask
class Decoder(nn.Module):
    """
    Flow-based decoder following Glow (https://arxiv.org/pdf/1807.03039.pdf):
    Squeeze -> 12 x (ActNorm, InvertibleConv, AffineCouplingLayer) -> Unsqueeze.
    The forward direction maps a mel-spectrogram to a latent z; reverse=True maps z back.
    """

    def __init__(self):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(12):
            for make_flow in (ActNorm, InvertibleConv, AffineCouplingLayer):
                self.flows.append(make_flow())

    def forward(self, x, x_mask, reverse=False):
        """
        =====inputs=====
        x: (B, 80, F) | mel-spectrogram (direct) or latent representation (reverse)
        x_mask: (B, 1, F)
        =====outputs=====
        z: (B, 80, F) | latent representation (direct) or mel-spectrogram (reverse)
        total_log_det: (B) summed log-determinant (direct) or None (reverse)
        """
        # (B, 80, F) -> (B, 160, F//2), (B, 1, F) -> (B, 1, F//2)
        x, x_mask = Squeeze(x, x_mask)
        if reverse:
            for flow in reversed(self.flows):
                x, _ = flow(x, x_mask, reverse=True)
            total_log_det = None
        else:
            total_log_det = 0
            for flow in self.flows:
                x, log_det = flow(x, x_mask, reverse=False)
                total_log_det += log_det
        # (B, 160, F//2) -> (B, 80, F), (B, 1, F//2) -> (B, 1, F)
        x, x_mask = Unsqueeze(x, x_mask)
        return x, total_log_det
""" | |
The Decoder follows the basic architecture of "Glow: Generative Flow with Invertible 1x1 Convolutions".
Glow paper: https://arxiv.org/pdf/1807.03039.pdf
""" | |
def Squeeze(x, x_mask):
    """
    Decoder preprocessing: fold each pair of adjacent frames into the channel axis.

    =====inputs=====
    x: (B, C, F) | mel-spectrogram or latent representation (C = 80 in this model)
    x_mask: (B, 1, F)
    =====outputs=====
    x: (B, 2*C, F//2) | an odd trailing frame is dropped; channels [0, C) hold the even
       frames and channels [C, 2C) the odd frames
    x_mask: (B, 1, F//2)
    """
    # Frame count is not named F so that torch.nn.functional (F) is not shadowed.
    B, C, n_frames = x.size()
    half = n_frames // 2
    x = x[:, :, :half * 2]  # drop the last frame when the length is odd
    x = x.view(B, C, half, 2)  # (B, C, F//2, 2)
    x = x.permute(0, 3, 1, 2).contiguous()  # (B, 2, C, F//2)
    x = x.view(B, C * 2, half)  # (B, 2C, F//2)
    x_mask = x_mask[:, :, 1::2]  # (B, 1, F//2): mask of frames 1, 3, 5, ...
    return x * x_mask, x_mask
class ActNorm(nn.Module):
    """
    Decoder flow step 1 (Glow ActNorm): per-channel affine transform z = x * exp(log_s) + bias.
    """

    def __init__(self):
        super().__init__()
        # log of Glow's per-channel scale s; exp(log_s) keeps the scale positive.
        self.log_s = nn.Parameter(torch.zeros(1, 160, 1))
        self.bias = nn.Parameter(torch.zeros(1, 160, 1))

    def forward(self, x, x_mask, reverse=False):
        """
        =====inputs=====
        x: (B, 160, F//2) | squeezed mel-spectrogram features
        x_mask: (B, 1, F//2) | mask produced by the Decoder's Squeeze
        =====outputs=====
        z: (B, 160, F//2)
        log_det: (B) log-determinant, or None when reverse=True
        """
        if reverse:
            # Inverse transform, then masking.
            return ((x - self.bias) / torch.exp(self.log_s)) * x_mask, None
        n_valid = torch.sum(x_mask, [1, 2])  # (B) number of valid frames per item
        z = (x * torch.exp(self.log_s) + self.bias) * x_mask
        # Glow paper, Table 1: log|det| = (number of positions) * sum(log|s|).
        return z, n_valid * torch.sum(self.log_s)
class InvertibleConv(nn.Module):
    """
    Decoder flow step 2: Glow's invertible 1x1 convolution with a learnable 4x4 matrix W.
    The 160 channels are regrouped as (4, 40) and W mixes the 4 entries at every
    (40, F//2) position.
    """

    def __init__(self):
        super().__init__()
        # Random orthogonal initialisation: Q from the QR decomposition of a standard-normal
        # 4x4 matrix (Q orthogonal, R upper triangular), so det(Q) is +1 or -1.
        Q = torch.linalg.qr(torch.randn(4, 4))[0]  # (4, 4)
        if torch.det(Q) < 0:
            # Flip the sign of column 0 so that det(Q) = +1; torch.logdet returns nan
            # for a matrix with a negative determinant.
            Q[:, 0] = -1 * Q[:, 0]
        self.W = nn.Parameter(Q)

    def forward(self, x, x_mask, reverse=False):
        """
        =====inputs=====
        x: (B, 160, F//2)
        x_mask: (B, 1, F//2)
        =====outputs=====
        z: (B, 160, F//2)
        log_det: (B) or None (reverse=True)
        """
        B, C, f = x.size()  # B, 160, F//2
        x_len = torch.sum(x_mask, [1, 2])  # (B) number of valid frames

        # channel mixing: (B, 160, F//2) -> (B, 2, 40, 2, F//2) -> (B, 2, 2, 40, F//2) -> (B, 4, 40, F//2)
        x = x.view(B, 2, C // 4, 2, f)
        x = x.permute(0, 1, 3, 2, 4).contiguous()
        x = x.view(B, 4, C // 4, f)

        if not reverse:
            weight = self.W
            # Glow log-determinant of a 1x1 conv: height * width * log|det W|,
            # with height = C/4 and width = x_len.
            log_det = (C / 4) * x_len * torch.logdet(self.W)  # (B)
        else:
            weight = torch.linalg.inv(self.W)  # inverse matrix
            log_det = None

        # (B, 4, 40, F//2) is treated as (batch, in_channels, height, width) and the 4x4
        # matrix as a (out_channels, in_channels, 1, 1) kernel, i.e. nn.Conv2d(4, 4, kernel_size=1).
        z = F.conv2d(x, weight.view(4, 4, 1, 1))  # (B, 4, 40, F//2)

        # channel unmixing (inverse of the regrouping above)
        z = z.view(B, 2, 2, C // 4, f)  # (B, 2, 2, 40, F//2)
        z = z.permute(0, 1, 3, 2, 4).contiguous()  # (B, 2, 40, 2, F//2)
        z = z.view(B, C, f) * x_mask  # (B, 160, F//2) & masking
        return z, log_det
class WN(nn.Module):
    """
    Gated non-causal convolution stack used inside AffineCouplingLayer.
    The structure comes from WaveGlow (https://arxiv.org/pdf/1811.00002.pdf).
    """

    def __init__(self, dilation_rate=1):
        super().__init__()
        hidden, n_layers, kernel = 192, 4, 5
        self.in_layers = nn.ModuleList()
        self.res_skip_layers = nn.ModuleList()
        for layer in range(n_layers):
            # NVIDIA's WaveGlow uses dilation_rate=2; with the default 1 every dilation is 1.
            dilation = dilation_rate ** layer
            self.in_layers.append(weight_norm(nn.Conv1d(
                hidden, 2 * hidden, kernel_size=kernel, dilation=dilation,
                padding=(kernel - 1) * dilation // 2)))  # (B, 192, F//2) -> (B, 384, F//2)
            # All layers but the last emit residual + skip halves; the last emits only skip.
            out_channels = hidden if layer == n_layers - 1 else 2 * hidden
            self.res_skip_layers.append(weight_norm(nn.Conv1d(hidden, out_channels, kernel_size=1)))
        self.dropout = nn.Dropout(0.05)

    def forward(self, x, x_mask):
        """
        =====inputs=====
        x: (B, 192, F//2)
        x_mask: (B, 1, F//2)
        =====outputs=====
        output: (B, 192, F//2) | masked sum of the skip outputs
        """
        hidden = 192
        last = len(self.in_layers) - 1
        skip_sum = torch.zeros_like(x)
        for layer, (in_layer, res_skip) in enumerate(zip(self.in_layers, self.res_skip_layers)):
            h = self.dropout(in_layer(x))  # (B, 384, F//2)
            # gated activation: sigmoid(second half) * tanh(first half)
            gated = torch.sigmoid(h[:, hidden:, :]) * torch.tanh(h[:, :hidden, :])
            res_skip_out = res_skip(gated)
            if layer < last:
                x = (x + res_skip_out[:, :hidden, :]) * x_mask  # residual connection & masking
                skip_sum += res_skip_out[:, hidden:, :]
            else:
                skip_sum += res_skip_out
        return skip_sum * x_mask
class AffineCouplingLayer(nn.Module):
    """
    Decoder flow step 3: affine coupling. The first half of the channels passes through
    unchanged and conditions a shift m and log-scale log_s applied to the second half.
    """

    def __init__(self):
        super().__init__()
        self.start_conv = weight_norm(nn.Conv1d(160 // 2, 192, kernel_size=1))  # (B, 80, F//2) -> (B, 192, F//2)
        self.wn = WN()
        self.end_conv = nn.Conv1d(192, 160, kernel_size=1)  # (B, 192, F//2) -> (B, 160, F//2)
        # Zero initialisation makes m = 0 and log_s = 0, so the layer starts as the identity,
        # which helps stabilise early training.
        nn.init.zeros_(self.end_conv.weight)
        nn.init.zeros_(self.end_conv.bias)

    def forward(self, x, x_mask, reverse=False):
        """
        =====inputs=====
        x: (B, 160, F//2)
        x_mask: (B, 1, F//2)
        =====outputs=====
        z: (B, 160, F//2)
        log_det: (B) or None (reverse=True)
        """
        half = x.size(1) // 2
        x_a, x_b = x[:, :half, :], x[:, half:, :]  # (B, 80, F//2) each
        h = self.start_conv(x_a) * x_mask  # (B, 192, F//2)
        h = self.wn(h, x_mask)  # (B, 192, F//2)
        stats = self.end_conv(h)  # (B, 160, F//2)
        shift, log_s = stats[:, :half, :], stats[:, half:, :]
        if reverse:
            x_b = (x_b - shift) / torch.exp(log_s) * x_mask  # inverse transform & masking
            log_det = None
        else:
            x_b = (torch.exp(log_s) * x_b + shift) * x_mask  # forward transform & masking
            log_det = torch.sum(log_s * x_mask, [1, 2])  # (B)
        return torch.cat([x_a, x_b], dim=1), log_det
def Unsqueeze(x, x_mask):
    """
    Decoder postprocessing: inverse of Squeeze. Channels [0, C/2) become the even frames
    and channels [C/2, C) the odd frames.

    =====inputs=====
    x: (B, 160, F//2)
    x_mask: (B, 1, F//2)
    =====outputs=====
    x: (B, 80, F)
    x_mask: (B, 1, F)
    """
    batch, channels, half = x.size()
    out_channels = channels // 2
    # (B, 160, F//2) -> (B, 2, 80, F//2) -> (B, 80, F//2, 2) -> (B, 80, F)
    unfolded = x.view(batch, 2, out_channels, half).permute(0, 2, 3, 1).contiguous()
    x = unfolded.view(batch, out_channels, 2 * half)
    # Each squeezed mask entry covers two output frames: (B, 1, F//2) -> (B, 1, F//2, 2) -> (B, 1, F)
    doubled = x_mask.unsqueeze(3).repeat(1, 1, 1, 2)
    x_mask = doubled.view(batch, 1, 2 * half)
    return x * x_mask, x_mask
class Encoder(nn.Module):
    """
    Text encoder: embedding -> PreNet -> TransformerEncoder, followed by a projection to the
    per-token prior mean, a fixed zero log-std (mean-only) and a log-duration prediction.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(symbol_length, 192)  # (B, T) -> (B, T, 192)
        nn.init.normal_(self.embedding.weight, 0.0, 192 ** (-0.5))  # N(0, 192^-0.5 ~= 0.072)
        self.prenet = PreNet()
        self.transformer_encoder = TransformerEncoder()
        self.project_mean = nn.Conv1d(192, 80, kernel_size=1)  # (B, 192, T) -> (B, 80, T)
        # Constructed but not used in forward (mean-only variant).
        self.project_std = nn.Conv1d(192, 80, kernel_size=1)  # (B, 192, T) -> (B, 80, T)
        self.duration_predictor = DurationPredictor()

    def forward(self, text, text_len):
        """
        =====inputs=====
        text: (B, Max_T)
        text_len: (B)
        =====outputs=====
        x_mean: (B, 80, T) | prior mean (80 output channels, as out_channels in the authors' train.py)
        x_std: (B, 80, T) | prior log-std, all zeros here
        x_dur: (B, 1, T) | predicted log-duration
        x_mask: (B, 1, T) | bool mask of valid tokens
        """
        # Embed, scale by sqrt(192) (~13.86), then move channels first: (B, T) -> (B, 192, T)
        hidden = (self.embedding(text) * math.sqrt(192)).transpose(1, 2)
        x_mask = torch.zeros_like(hidden[:, 0:1, :], dtype=torch.bool)  # (B, 1, T)
        for b, n_tokens in enumerate(text_len):
            x_mask[b, :, :n_tokens] = True
        hidden = self.prenet(hidden, x_mask)  # (B, 192, T)
        hidden = self.transformer_encoder(hidden, x_mask)  # (B, 192, T)
        x_mean = self.project_mean(hidden) * x_mask  # (B, 80, T)
        # mean-only: log std = 0, i.e. std = 1
        x_log_std = torch.zeros_like(x_mean)
        # The duration predictor sees a detached copy, so no gradient flows back into the encoder from it.
        x_log_dur = self.duration_predictor(torch.detach(hidden), x_mask)  # (B, 1, T)
        return x_mean, x_log_std, x_log_dur, x_mask
class LayerNorm(nn.Module):
    """
    Layer normalisation over the channel axis (dim 1).
    nn.LayerNorm always normalises the trailing dimensions, so this variant normalises
    (B, channels, ...) inputs along `channels` instead.
    """

    def __init__(self, channels):
        """
        channels: number of channels in the input; normalisation runs over this axis.
        """
        super().__init__()
        self.channels = channels
        self.eps = 1e-4
        self.gamma = nn.Parameter(torch.ones(channels))  # learnable scale
        self.beta = nn.Parameter(torch.zeros(channels))  # learnable shift

    def forward(self, x):
        """
        =====inputs=====
        x: (B, channels, *) | data to normalise
        =====outputs=====
        x: (B, channels, *) | normalised along the channel axis, then scaled and shifted
        """
        centered = x - x.mean(dim=1, keepdim=True)
        variance = centered.pow(2).mean(dim=1, keepdim=True)
        normed = centered * (variance + self.eps) ** (-0.5)  # (x - mean) / sqrt(var + eps)
        # Broadcast the per-channel parameters as [1, -1, 1, ...].
        param_shape = [1, -1] + [1] * (x.dim() - 2)
        return normed * self.gamma.view(*param_shape) + self.beta.view(*param_shape)
class PreNet(nn.Module):
    """
    First Encoder module: 3 x (masked Conv1d k=5 -> LayerNorm -> ReLU -> Dropout),
    a 1x1 projection, and a residual connection around the whole stack.
    """

    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        for _ in range(3):
            self.convs.append(nn.Conv1d(192, 192, kernel_size=5, padding=2))  # keeps (B, 192, T)
            self.norms.append(LayerNorm(192))  # keeps (B, 192, T)
        self.linear = nn.Conv1d(192, 192, kernel_size=1)  # 1x1 conv acting as a per-frame linear layer

    def forward(self, x, x_mask):
        """
        =====inputs=====
        x: (B, 192, T) | embedded text
        x_mask: (B, 1, T) | True where a token exists, False on padding
        =====outputs=====
        x: (B, 192, T)
        """
        residual = x
        for conv, norm in zip(self.convs, self.norms):
            x = self.dropout(self.relu(norm(conv(x * x_mask))))
        return residual + self.linear(x)
class MultiHeadAttention(nn.Module):
    """
    First sub-module of TransformerEncoder: 2-head self-attention with relative position
    representations (Shaw et al., https://arxiv.org/pdf/1803.02155.pdf), window size 4.
    """

    def __init__(self):
        super().__init__()
        self.n_heads = 2
        self.window_size = 4
        self.k_channels = 192 // self.n_heads  # 96
        self.linear_q = nn.Conv1d(192, 192, kernel_size=1)  # keeps (B, 192, T)
        self.linear_k = nn.Conv1d(192, 192, kernel_size=1)  # keeps (B, 192, T)
        self.linear_v = nn.Conv1d(192, 192, kernel_size=1)  # keeps (B, 192, T)
        nn.init.xavier_uniform_(self.linear_q.weight)
        nn.init.xavier_uniform_(self.linear_k.weight)
        nn.init.xavier_uniform_(self.linear_v.weight)
        relative_std = self.k_channels ** (-0.5)  # ~0.102
        # One embedding per relative offset -window_size..+window_size: (1, 9, 96)
        self.relative_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.k_channels) * relative_std)
        self.relative_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.k_channels) * relative_std)
        self.attention_weights = None  # attention map of the last forward call
        self.linear_out = nn.Conv1d(192, 192, kernel_size=1)  # keeps (B, 192, T)
        self.dropout = nn.Dropout(0.1)

    def _relative_embeddings(self, table, length):
        """
        Slice `table` (1, 2*window_size+1, k) to the (1, 2*length-1, k) embeddings of offsets
        -(length-1)..(length-1). Offsets beyond +-window_size get zero embeddings via padding.
        """
        padding = max(length - (self.window_size + 1), 0)
        start_pos = max((self.window_size + 1) - length, 0)
        end_pos = start_pos + 2 * length - 1
        # F.pad lists pads from the last dim backwards: (0, 0) on the channel dim,
        # (padding, padding) on the offset dim.
        padded = F.pad(table, (0, 0, padding, padding))
        return padded[:, start_pos:end_pos, :]

    @staticmethod
    def _relative_to_absolute(x):
        """(B, H, T, 2T-1) logits indexed by relative offset -> (B, H, T, T) indexed by key position."""
        B, H, T, _ = x.size()
        x = F.pad(x, (0, 1))  # (B, H, T, 2T)
        x = x.view(B, H, T * 2 * T)  # (B, H, 2T^2)
        x = F.pad(x, (0, T - 1))  # (B, H, 2T^2 + T - 1) = (B, H, (T+1)(2T-1))
        x = x.view(B, H, T + 1, 2 * T - 1)
        return x[:, :, :T, T - 1:]  # (B, H, T, T)

    @staticmethod
    def _absolute_to_relative(x):
        """(B, H, T, T) weights indexed by key position -> (B, H, T, 2T-1) indexed by relative offset."""
        B, H, T, _ = x.size()
        x = F.pad(x, (0, T - 1))  # (B, H, T, 2T-1)
        x = x.view(B, H, T * (2 * T - 1))  # (B, H, 2T^2 - T)
        x = F.pad(x, (T, 0))  # (B, H, 2T^2), front padding skews the rows
        x = x.view(B, H, T, 2 * T)
        return x[:, :, :, 1:]  # (B, H, T, 2T-1)

    def forward(self, query, context, attention_mask, self_attention=True):
        """
        =====inputs=====
        query: (B, 192, T_target) | Glow-TTS only uses self-attention, so query and context are the same x
        context: (B, 192, T_source) | hence T_source = T_target here
        attention_mask: (B, 1, T, T) | x_mask.unsqueeze(2) * x_mask.unsqueeze(3)
        self_attention: True/False | relative position representations are applied when True (always here)
        =====outputs=====
        output: (B, 192, T)
        """
        query = self.linear_q(query)
        key = self.linear_k(context)
        value = self.linear_v(context)
        B, _, T_tar = query.size()
        T_src = key.size(2)
        # (B, 192, T) -> (B, n_heads, 96, T) -> (B, n_heads, T, 96)
        query = query.view(B, self.n_heads, self.k_channels, T_tar).transpose(2, 3)
        key = key.view(B, self.n_heads, self.k_channels, T_src).transpose(2, 3)
        value = value.view(B, self.n_heads, self.k_channels, T_src).transpose(2, 3)
        scale = self.k_channels ** 0.5
        scores = torch.matmul(query, key.transpose(2, 3)) / scale  # (B, 2, T_tar, T_src)

        if self_attention:
            # Eq. (5) of Shaw et al.: add query . relative-key logits (T_tar == T_src here).
            relative_keys = self._relative_embeddings(self.relative_k, T_src)  # (1, 2T-1, 96)
            relative_keys = relative_keys.unsqueeze(0).transpose(2, 3)  # (1, 1, 96, 2T-1)
            relative_logits = torch.matmul(query, relative_keys)  # (B, 2, T, 2T-1)
            scores = scores + self._relative_to_absolute(relative_logits) / scale  # (B, 2, T, T)

        # Large negative fill so masked (padded) positions get ~0 weight after the softmax.
        # A small value such as -1e-4 would be almost 0 and leave padding effectively unmasked.
        scores = scores.masked_fill(attention_mask == 0, -1e4)
        attention_weights = F.softmax(scores, dim=-1)  # (B, 2, T_tar, T_src), alpha in Shaw et al.
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, value)  # (B, 2, T_tar, 96)

        if self_attention:
            # Eq. (3) of Shaw et al. (distributed): add attention-weighted relative values.
            relative_weights = self._absolute_to_relative(attention_weights)  # (B, 2, T, 2T-1)
            relative_values = self._relative_embeddings(self.relative_v, T_src).unsqueeze(0)  # (1, 1, 2T-1, 96)
            output = output + torch.matmul(relative_weights, relative_values)  # (B, 2, T, 96)

        # (B, 2, T, 96) -> (B, 2, 96, T) -> contiguous -> (B, 192, T)
        output = output.transpose(2, 3).contiguous().view(B, 192, T_tar)
        self.attention_weights = attention_weights  # (B, 2, T, T)
        return self.linear_out(output)  # (B, 192, T)
class FFN(nn.Module):
    """
    Second sub-module of TransformerEncoder: Conv1d(k=3) -> ReLU -> Dropout -> Conv1d(k=3).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(192, 768, kernel_size=3, padding=1)  # (B, 192, T) -> (B, 768, T)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(768, 192, kernel_size=3, padding=1)  # (B, 768, T) -> (B, 192, T)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, x_mask):
        """
        =====inputs=====
        x: (B, 192, T)
        x_mask: (B, 1, T)
        =====outputs=====
        output: (B, 192, T) | zero on padded positions
        """
        # Mask before each kernel-3 conv: otherwise non-zero values on padded frames leak
        # into the neighbouring valid frames and outputs depend on the amount of padding.
        x = self.conv1(x * x_mask)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv2(x * x_mask)
        return x * x_mask
class TransformerEncoder(nn.Module):
    """
    Second Encoder module: 6 post-norm layers of
    (self-attention + residual -> LayerNorm) and (FFN + residual -> LayerNorm).
    """

    def __init__(self):
        super().__init__()
        self.attentions = nn.ModuleList()
        self.norms1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.norms2 = nn.ModuleList()
        for _ in range(6):
            self.attentions.append(MultiHeadAttention())
            self.norms1.append(LayerNorm(192))
            self.ffns.append(FFN())
            self.norms2.append(LayerNorm(192))
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, x_mask):
        """
        =====inputs=====
        x: (B, 192, T)
        x_mask: (B, 1, T)
        =====outputs=====
        output: (B, 192, T)
        """
        # (B, 1, 1, T) * (B, 1, T, 1) = (B, 1, T, T): 1 only where both positions are valid
        attention_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(3)
        layers = zip(self.attentions, self.norms1, self.ffns, self.norms2)
        for attention, norm1, ffn, norm2 in layers:
            x = x * x_mask
            x = norm1(x + self.dropout(attention(x, x, attention_mask)))  # residual + norm
            x = norm2(x + self.dropout(ffn(x, x_mask)))  # residual + norm
        return x * x_mask
class DurationPredictor(nn.Module):
    """
    Third Encoder module: predicts one log-duration per text token with
    2 x (masked Conv1d k=3 -> ReLU -> LayerNorm -> Dropout) and a 1x1 projection.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(192, 256, kernel_size=3, padding=1)  # (B, 192, T) -> (B, 256, T)
        self.norm1 = LayerNorm(256)
        self.conv2 = nn.Conv1d(256, 256, kernel_size=3, padding=1)  # (B, 256, T) -> (B, 256, T)
        self.norm2 = LayerNorm(256)
        self.linear = nn.Conv1d(256, 1, kernel_size=1)  # (B, 256, T) -> (B, 1, T)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x, x_mask):
        """
        =====inputs=====
        x: (B, 192, T)
        x_mask: (B, 1, T)
        =====outputs=====
        output: (B, 1, T)
        """
        for conv, norm in ((self.conv1, self.norm1), (self.conv2, self.norm2)):
            x = self.dropout(norm(self.relu(conv(x * x_mask))))
        return self.linear(x * x_mask) * x_mask