# Glow-HiFi-TTS / Tmodel.py
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
symbol_length = 73
class GlowTTS(nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, text, text_len, mel=None, mel_len=None, inference=False, noise_scale=1., length_scale=1.):
"""
=====inputs=====
text: (B, T)
text_len: (B) list
mel: (B, 80, F)
mel_len: (B) list
inference: True/False
=====outputs=====
(tuple) (z, z_mean, z_log_std, log_det, z_mask)
z(training) or y(inference): (B, 80, F) | z: latent representation, y: mel-spectrogram
z_mean: (B, 80, F)
z_log_std: (B, 80, F)
log_det: (B) or None
z_mask: (B, 1, F)
(tuple) (x_mean, x_log_std, x_mask)
x_mean: (B, 80, T)
x_log_std: (B, 80, T)
x_mask: (B, 1, T)
(tuple) (attention_alignment, x_log_dur, log_d)
attention_alignment: (B, T, F)
        x_log_dur: (B, 1, T) | predicted duration, in log scale
        log_d: (B, 1, T) | log-scale duration taken from the alignment judged to be appropriate
"""
x_mean, x_log_std, x_log_dur, x_mask = self.encoder(text, text_len)
        # x_std and x_dur carry the log prefix because the paper authors' implementation treats these values as already being in log scale.
y, y_len = mel, mel_len
if not inference: # training
y_max_len = y.size(2)
else: # inference
dur = torch.exp(x_log_dur) * x_mask * length_scale # (B, 1, T)
ceil_dur = torch.ceil(dur) # (B, 1, T)
y_len = torch.clamp_min(torch.sum(ceil_dur, [1, 2]), 1).long() # (B)
# ceil_dur์„ [1, 2] ์ถ•์— ๋Œ€ํ•ด sumํ•œ ๋’ค ์ตœ์†Ÿ๊ฐ’์ด 1์ด์ƒ์ด ๋˜๋„๋ก ์„ค์ •. ์ •์ˆ˜ long ํƒ€์ž…์œผ๋กœ ๋ฐ˜ํ™˜ํ•œ๋‹ค.
y_max_len = None
# preprocessing
if y_max_len is not None:
            y_max_len = (y_max_len // 2) * 2 # if odd, subtract 1 to make it even
            y = y[:, :, :y_max_len] # trim y to y_max_len
        y_len = (y_len // 2) * 2 # if y_len is odd, subtract 1 to make it even
# make the z_mask
B = len(y_len)
temp_max = max(y_len)
z_mask = torch.zeros((B, 1, temp_max), dtype=torch.bool).to(device) # (B, 1, F)
for idx, length in enumerate(y_len):
z_mask[idx, :, :length] = True
# make the attention_mask
attention_mask = x_mask.unsqueeze(3) * z_mask.unsqueeze(2) # (B, 1, T, 1) * (B, 1, 1, F) = (B, 1, T, F)
# ์ฃผ์˜: Encoder์˜ attention_mask์™€๋Š” ๋‹ค๋ฅธ mask์ž„.
if not inference: # training
z, log_det = self.decoder(y, z_mask, reverse=False)
with torch.no_grad():
                x_std_squared_root = torch.exp(-2 * x_log_std) # (B, 80, T) | exp(-2*log_std) = 1/std^2, i.e. the inverse variance (the variable name is kept from the original code)
                logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_log_std, [1]).unsqueeze(-1) # (B, T, 1), broadcasts to (B, T, F)
                logp2 = torch.matmul(x_std_squared_root.transpose(1, 2), -0.5 * (z ** 2)) # (B, T, 80) * (B, 80, F) = (B, T, F)
                logp3 = torch.matmul((x_mean * x_std_squared_root).transpose(1,2), z) # (B, T, 80) * (B, 80, F) = (B, T, F)
                logp4 = torch.sum(-0.5 * (x_mean ** 2) * x_std_squared_root, [1]).unsqueeze(-1) # (B, T, 1), broadcasts to (B, T, F)
logp = logp1 + logp2 + logp3 + logp4 # (B, T, F)
"""
logp๋Š” normal distribution N(x_mean, x_std)์˜ maximum log-likelihood์ด๋‹ค.
sum(log(N(z;x_mean, x_std)))๋ฅผ ์ •๊ทœ๋ถ„ํฌ ์‹์„ ์ด์šฉํ•˜์—ฌ ๋ถ„๋ฐฐ๋ฒ•์น™์œผ๋กœ ํ’€์–ด๋‚ด๋ฉด ์œ„์™€ ๊ฐ™์€ ์‹์ด ๋„์ถœ๋œ๋‹ค.
"""
attention_alignment = maximum_path(logp, attention_mask.squeeze(1)).detach() # alignment (B, T, F)
z_mean = torch.matmul(attention_alignment.transpose(1, 2), x_mean.transpose(1, 2)) # (B, F, T) * (B, T, 80) -> (B, F, 80)
z_mean = z_mean.transpose(1, 2) # (B, 80, F)
z_log_std = torch.matmul(attention_alignment.transpose(1, 2), x_log_std.transpose(1, 2)) # (B, F, T) * (B, T, 80) -> (B, F, 80)
z_log_std = z_log_std.transpose(1, 2) # (B, 80, F)
            log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask # (B, 1, T) | log-scale duration formed from the alignment
return (z, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
else: # inference
# generate_path (make attention_alignment using ceil(x_dur))
attention_alignment = generate_path(ceil_dur.squeeze(1), attention_mask.squeeze(1)) # (B, T, F)
z_mean = torch.matmul(attention_alignment.transpose(1, 2), x_mean.transpose(1, 2)) # (B, F, T) * (B, T, 80) -> (B, F, 80)
z_mean = z_mean.transpose(1, 2) # (B, 80, F)
z_log_std = torch.matmul(attention_alignment.transpose(1, 2), x_log_std.transpose(1, 2)) # (B, F, T) * (B, T, 80) -> (B, F, 80)
z_log_std = z_log_std.transpose(1, 2) # (B, 80, F)
            log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask # (B, 1, T) | log-scale duration formed from the alignment
            z = (z_mean + torch.exp(z_log_std) * torch.randn_like(z_mean) * noise_scale) * z_mask # sample z (the latent representation)
            y, log_det = self.decoder(z, z_mask, reverse=True) # generate the mel-spectrogram
return (y, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
##### The paper authors' implementation below is much faster; the implementation above should be revised with it in mind. #####
##### (Note: the pure-Python definitions of the same names later in this file shadow these faster versions.) #####
def maximum_path(value, mask, max_neg_val=-np.inf):
""" Numpy-friendly version. It's about 4 times faster than torch version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
"""
value = value * mask
device = value.device
dtype = value.dtype
value = value.cpu().detach().numpy()
mask = mask.cpu().detach().numpy().astype(bool)
b, t_x, t_y = value.shape
direction = np.zeros(value.shape, dtype=np.int64)
v = np.zeros((b, t_x), dtype=np.float32)
x_range = np.arange(t_x, dtype=np.float32).reshape(1,-1)
for j in range(t_y):
v0 = np.pad(v, [[0,0],[1,0]], mode="constant", constant_values=max_neg_val)[:, :-1]
v1 = v
max_mask = (v1 >= v0)
v_max = np.where(max_mask, v1, v0)
direction[:, :, j] = max_mask
index_mask = (x_range <= j)
v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
direction = np.where(mask, direction, 1)
path = np.zeros(value.shape, dtype=np.float32)
index = mask[:, :, 0].sum(1).astype(np.int64) - 1
index_range = np.arange(b)
for j in reversed(range(t_y)):
path[index_range, index, j] = 1
index = index + direction[index_range, index, j] - 1
path = path * mask.astype(np.float32)
path = torch.from_numpy(path).to(device=device, dtype=dtype)
return path
def generate_path(duration, mask):
"""
duration: [b, t_x]
mask: [b, t_x, t_y]
"""
device = duration.device
b, t_x, t_y = mask.shape # (B, T, F)
    cum_duration = torch.cumsum(duration, 1) # cumulative sum, (B, T)
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) # (B, T, F)
cum_duration_flat = cum_duration.view(b * t_x) # (B*T)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) # (B*T, F)
path = path.view(b, t_x, t_y) # (B, T, F)
path = path.to(torch.float32)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:,:-1] # (B, T, F) | shift by one along the T axis and subtract, leaving each token's own frames
path = path * mask
return path
def sequence_mask(length, max_length=None):
if max_length is None:
max_length = length.max()
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
return x.unsqueeze(0) < length.unsqueeze(1)
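# Example (a sketch): sequence_mask(torch.tensor([2, 3]), 4) ->
# tensor([[True, True, False, False],
#         [True, True, True, False]])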
def convert_pad_shape(pad_shape):
l = pad_shape[::-1] # [[0, 0], [p, p], [0, 0]]
pad_shape = [item for sublist in l for item in sublist] # [0, 0, p, p, 0, 0]
return pad_shape
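# Example (a sketch): convert_pad_shape([[0, 0], [1, 0], [0, 0]]) -> [0, 0, 1, 0, 0, 0];
# F.pad takes pad widths ordered from the last dimension backwards, hence the reversal.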
def MAS(path, logp, T_max, F_max):
"""
Glow-TTS์˜ ๋ชจ๋“ˆ์ธ maximum_path์˜ ๋ชจ๋“ˆ
MAS ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ์ˆ˜ํ–‰ํ•˜๋Š” ํ•จ์ˆ˜์ด๋‹ค.
=====inputs=====
path: (T, F)
logp: (T, F)
T_max: (1)
F_max: (1)
=====outputs=====
    path: (T, F) | alignment consisting of 0s and 1s
"""
neg_inf = -1e9 # negative infinity
# forward
for j in range(F_max):
        for i in range(max(0, T_max + j - F_max), min(T_max, j + 1)): # the reachable (i, j) cells form a parallelogram
            # Q_{i, j-1} (current)
            if i == j:
                Q_cur = neg_inf
            else:
                Q_cur = logp[i, j-1] # if j = 0 then i = 0 as well, so this branch never reads index j-1 at j = 0
            # Q_{i-1, j-1} (previous)
            if i==0:
                if j==0:
                    Q_prev = 0. # at i = 0, j = 0 only the logp value itself should be counted
                else:
                    Q_prev = neg_inf # when i = 0 (and j > 0), Q_{i-1, j-1} must not contribute
            else:
                Q_prev = logp[i-1, j-1]
            # update logp in place with Q
            logp[i, j] = max(Q_cur, Q_prev) + logp[i, j]
# backtracking
idx = T_max - 1
    for j in range(F_max-1, -1, -1): # from F_max-1 down to 0 (excluding -1), stepping by -1
path[idx, j] = 1
if idx != 0:
if (logp[idx, j-1] < logp[idx-1, j-1]) or (idx == j):
idx -= 1
return path
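# Worked example (a sketch): with T_max = 2, F_max = 3 and
#   logp = [[ 0, -1, -2],
#           [-2,  0,  0]]
# the forward pass accumulates the best monotonic scores, and backtracking yields
#   path = [[1, 0, 0],
#           [0, 1, 1]]
# i.e. the first token covers frame 0 and the second token covers frames 1 and 2.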
def maximum_path(logp, attention_mask):
"""
    Module used in Glow-TTS.
    Finds the alignment using MAS.
    The paper authors appear to parallelize this with Cython,
    but here it is implemented in pure Python.
=====inputs=====
    logp: (B, T, F) | log-likelihood under N(x_mean, x_std)
attention_mask: (B, T, F)
=====outputs=====
path: (B, T, F) | alignment
"""
B = logp.shape[0]
logp = logp * attention_mask
# ๊ณ„์‚ฐ์€ CPU์—์„œ ์‹คํ–‰๋˜๋„๋ก ํ•˜๊ธฐ ์œ„ํ•ด ๊ธฐ์กด์˜ device๋ฅผ ์ €์žฅํ•˜๊ณ  .cpu().numpy()๋ฅผ ํ•œ๋‹ค.
logp_device = logp.device
logp_type = logp.dtype
logp = logp.data.cpu().numpy().astype(np.float32)
attention_mask = attention_mask.data.cpu().numpy()
path = np.zeros_like(logp).astype(np.int32) # (B, T, F)
T_max = attention_mask.sum(1)[:, 0].astype(np.int32) # (B)
F_max = attention_mask.sum(2)[:, 0].astype(np.int32) # (B)
    # MAS algorithm
for idx in range(B):
path[idx] = MAS(path[idx], logp[idx], T_max[idx], F_max[idx]) # (T, F)
return torch.from_numpy(path).to(device=logp_device, dtype=logp_type)
def generate_path(ceil_dur, attention_mask):
"""
    Module used in Glow-TTS.
    Builds the alignment during inference.
=====input=====
    ceil_dur: (B, T) | ceil of the predicted durations | ex) [[2, 1, 2, 2, ...], [1, 2, 1, 3, ...], ...]
attention_mask: (B, T, F)
=====output=====
path: (B, T, F) | alignment
"""
B, T, Frame = attention_mask.shape
cum_dur = torch.cumsum(ceil_dur, 1)
    cum_dur = cum_dur.to(torch.int32) # (B, T) | cumulative sum | ex) [[2, 3, 5, 7, ...], [1, 3, 4, 7, ...], ...]
path = torch.zeros(B, T, Frame).to(ceil_dur.device) # (B, T, F) | all False(0)
# make the sequence_mask
    for b, batch_cum_dur in enumerate(cum_dur):
        for t, each_cum_dur in enumerate(batch_cum_dur):
            path[b, t, :each_cum_dur] = 1.
            # carve True (1) into path up to each cumulative duration; scalar assignment also
            # stays valid when a cumulative duration exceeds Frame after the even-length truncation of y_len
path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1] # (B, T, F)
"""
    ex) an example with the batch dimension omitted:
[[1, 1, 0, 0, 0, 0, 0], [[0, 0, 0, 0, 0, 0, 0], [[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], = [0, 0, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1]] [1, 1, 1, 1, 1, 0, 0]] [0, 0, 0, 0, 0, 1, 1]]
"""
path = path * attention_mask
return path
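# Usage sketch (assumed toy values): ceil_dur = [[2., 1., 2.]] with an all-ones
# attention_mask of shape (1, 3, 5) yields
#   path = [[[1, 1, 0, 0, 0],
#            [0, 0, 1, 0, 0],
#            [0, 0, 0, 1, 1]]]
# i.e. token t occupies exactly ceil_dur[t] consecutive frames.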
class Decoder(nn.Module):
def __init__(self):
super().__init__()
self.flows = nn.ModuleList()
for i in range(12):
self.flows.append(ActNorm())
self.flows.append(InvertibleConv())
self.flows.append(AffineCouplingLayer())
def forward(self, x, x_mask, reverse=False):
"""
=====inputs=====
x: (B, 80, F) | mel-spectrogram(Direct) OR latent representation(Reverse)
x_mask: (B, 1, F)
=====outputs=====
z: (B, 80, F) | latent representation(Direct) OR mel-spectrogram(Reverse)
total_log_det: (B) or None | log determinant
"""
if not reverse:
flows = self.flows
total_log_det = 0
else:
flows = reversed(self.flows)
total_log_det = None
x, x_mask = Squeeze(x, x_mask) # (B, 80, F) -> (B, 160, F//2) | (B, 1, F) -> (B, 1, F//2)
for f in flows:
if not reverse:
x, log_det = f(x, x_mask, reverse=reverse)
total_log_det += log_det
else:
x, _ = f(x, x_mask, reverse=reverse)
x, x_mask = Unsqueeze(x, x_mask) # (B, 160, F//2) -> (B, 80, F) | (B, 1, F//2) -> (B, 1, F)
return x, total_log_det
"""
The Decoder follows the basic structure of Glow: Generative Flow with Invertible 1x1 Convolutions.
Glow paper: https://arxiv.org/pdf/1807.03039.pdf
"""
def Squeeze(x, x_mask):
"""
Decoder์˜ preprocessing
=====inputs=====
x: (B, 80, F) | mel_spectrogram or latent representation
x_mask: (B, 1, F)
=====outputs=====
    x: (B, 160, F//2) | F//2 = floor(F/2)
    x_mask: (B, 1, F//2)
"""
B, C, F = x.size()
    x = x[:, :, :(F//2)*2] # if F is odd, drop the last frame
x = x.view(B, C, F//2, 2) # (B, 80, F//2, 2)
x = x.permute(0, 3, 1, 2).contiguous() # (B, 2, 80, F//2)
x = x.view(B, C*2, F//2) # (B, 160, F//2)
    x_mask = x_mask[:, :, 1::2] # (B, 1, F//2) | take every second frame, starting from index 1
x = x * x_mask # masking
return x, x_mask
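# Note on Squeeze: element [b, c, k, p] of the (B, C, F//2, 2) view is x[b, c, 2k+p], so after the
# permute/view the first C output channels hold the even frames (2k) and the last C channels the odd
# frames (2k+1); Unsqueeze below inverts this exactly.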
class ActNorm(nn.Module):
"""
Decoder์˜ 1๋ฒˆ์งธ ๋ชจ๋“ˆ
"""
def __init__(self):
super().__init__()
        self.log_s = nn.Parameter(torch.zeros(1, 160, 1)) # the log of s from the Glow paper, i.e. log[s]
self.bias = nn.Parameter(torch.zeros(1, 160, 1))
def forward(self, x, x_mask, reverse=False):
"""
=====inputs=====
        x: (B, 160, F//2) | mel-spectrogram features
        x_mask: (B, 1, F//2) | mask for the mel-spectrogram features (reshaped by the Decoder's Squeeze)
=====outputs=====
z: (B, 160, F//2)
        log_det: (B) or None | log-determinant; None is returned when reverse=True
"""
        x_len = torch.sum(x_mask, [1, 2]) # (B) | sum over dims 1 and 2. cf. summing over [2] alone would give shape (B, 1).
if not reverse:
z = (x * torch.exp(self.log_s) + self.bias) * x_mask # function & masking
log_det = x_len * torch.sum(self.log_s) # log_determinant
            # See Table 1 of the Glow paper; log_s can be read as log[s].
            # The log-determinant is used instead of the determinant, presumably because of its smaller magnitude and lower computational cost.
else:
z = ((x - self.bias) / torch.exp(self.log_s)) * x_mask # inverse function & masking
log_det = None
return z, log_det
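# Note: with the initial values log_s = 0 and bias = 0, ActNorm is the identity on masked
# positions and log_det = x_len * 0 = 0, so training starts from a neutral transform.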
class InvertibleConv(nn.Module):
"""
Decoder์˜ 2๋ฒˆ์งธ ๋ชจ๋“ˆ
"""
def __init__(self):
super().__init__()
        Q = torch.linalg.qr(torch.FloatTensor(4, 4).normal_())[0] # (4, 4)
        """
        torch.FloatTensor(4, 4).normal_(): a 4x4 matrix drawn at random from the normal distribution N(0, 1)
        Q, R = torch.linalg.qr(W): QR decomposition | Q: orthogonal matrix, R: upper triangular matrix. cf. det(Q) = 1 or -1
        """
        if torch.det(Q) < 0:
            Q[:, 0] = -1 * Q[:, 0] # flip the sign of the first column so that det(Q) = +1
self.W = nn.Parameter(Q)
def forward(self, x, x_mask, reverse=False):
"""
=====inputs=====
x: (B, 160, F//2)
x_mask: (B, 1, F//2)
=====outputs=====
z: (B, 160, F//2)
log_det: (B) or None
"""
B, C, f = x.size() # B, 160, F//2
x_len = torch.sum(x_mask, [1, 2]) # (B)
# channel mixing
x = x.view(B, 2, C//4, 2, f) # (B, 2, 40, 2, F//2)
x = x.permute(0, 1, 3, 2, 4).contiguous() # (B, 2, 2, 40, F//2)
x = x.view(B, 4, C//4, f) # (B, 4, 40, F//2)
# ํŽธ์˜์ƒ log_det๋ถ€ํ„ฐ ๊ตฌํ•œ๋‹ค.
if not reverse:
weight = self.W
log_det = (C/4) * x_len * torch.logdet(self.W) # (B) | torch.logdet(W): log(det(W))
            # Considering that height = C/4 and width = x_len, this matches the log-determinant formula in the Glow paper.
else:
weight = torch.linalg.inv(self.W) # inverse matrix
log_det = None
weight = weight.view(4, 4, 1, 1)
z = F.conv2d(x, weight) # (B, 4, 40, F//2) * (4, 4, 1, 1) -> (B, 4, 40, F//2)
"""
F.conv2d(x, weight)์˜ convolution ์—ฐ์‚ฐ์€ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ƒ๊ฐํ•ด์•ผ ํ•œ๋‹ค.
(B, 4, 40, F//2): (batch_size, in_channels, height, width)
(4, 4, 1, 1): (out_channels, in_channels/groups, kernel_height, kernel_width)
์ฆ‰, nn.Conv2d(4, 4, kernel_size=(1, 1))์ธ ์ƒํ™ฉ์— ๊ฐ€์ค‘์น˜๋ฅผ ์ค€ ๊ฒƒ์ด๋‹ค.
"""
# channel unmixing
z = z.view(B, 2, 2, C//4, f) # (B, 4, 40, F//2) -> (B, 2, 2, 40, F//2)
z = z.permute(0, 1, 3, 2, 4).contiguous() # (B, 2, 40, 2, F//2)
z = z.view(B, C, f) * x_mask # (B, 160, F//2) & masking
return z, log_det
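# Note: the 1x1 convolution is linear per position, so reverse=True with W^{-1} exactly undoes
# reverse=False (up to float error), and the channel mixing/unmixing permutations are mutual inverses.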
class WN(nn.Module):
"""
Decoder์˜ 3๋ฒˆ์งธ ๋ชจ๋“ˆ์ธ AffineCouplingLayer์˜ ๋ชจ๋“ˆ
ํ•ด๋‹น ๊ตฌ์กฐ๋Š” WAVEGLOW: A FLOW-BASED GENERATIVE NETWORK FOR SPEECH SYNTHESIS ๋กœ๋ถ€ํ„ฐ ์ œ์•ˆ๋˜์—ˆ๋‹ค.
WaveGlow ๋…ผ๋ฌธ: https://arxiv.org/pdf/1811.00002.pdf
"""
def __init__(self, dilation_rate=1):
super().__init__()
self.in_layers = nn.ModuleList()
self.res_skip_layers = nn.ModuleList()
for i in range(4):
            dilation = dilation_rate ** i # NVIDIA's WaveGlow uses dilation_rate=2, but here it is 1, so this has no effect.
in_layer = weight_norm(nn.Conv1d(192, 2*192, kernel_size=5, dilation=dilation,
padding=((5-1) * dilation)//2)) # (B, 192, F//2) -> (B, 2*192, F//2)
self.in_layers.append(in_layer)
if i < 3:
res_skip_layer = weight_norm(nn.Conv1d(192, 2*192, kernel_size=1)) # (B, 192, F//2) -> (B, 2*192, F//2)
else:
res_skip_layer = weight_norm(nn.Conv1d(192, 192, kernel_size=1)) # (B, 192, F//2) -> (B, 192, F//2)
self.res_skip_layers.append(res_skip_layer)
self.dropout = nn.Dropout(0.05)
def forward(self, x, x_mask):
"""
=====inputs=====
x: (B, 192, F//2)
x_mask: (B, 1, F//2)
=====outputs=====
output: (B, 192, F//2)
"""
output = torch.zeros_like(x) # (B, 192, F//2) all zeros
for i in range(4):
x_in = self.in_layers[i](x) # (B, 192, F//2) -> (B, 2*192, F//2)
x_in = self.dropout(x_in) # dropout
# fused add tanh sigmoid multiply
tanh_act = torch.tanh(x_in[:, :192, :]) # (B, 192, F//2)
sigmoid_act = torch.sigmoid(x_in[:, 192:, :]) # (B, 192, F//2)
acts = sigmoid_act * tanh_act # (B, 192, F//2)
x_out = self.res_skip_layers[i](acts) # (B, 192, F//2) -> (B, 2*192, F//2) or [last](B, 192, F//2)
if i < 3:
x = (x + x_out[:, :192, :]) * x_mask # residual connection & masking
output += x_out[:, 192:, :] # add output
else:
output += x_out # (B, 192, F//2)
output = output * x_mask # masking
return output
class AffineCouplingLayer(nn.Module):
"""
Decoder์˜ 3๋ฒˆ์งธ ๋ชจ๋“ˆ
"""
def __init__(self):
super().__init__()
self.start_conv = weight_norm(nn.Conv1d(160//2, 192, kernel_size=1)) # (B, 80, F//2) -> (B, 192, F//2)
self.wn = WN()
self.end_conv = nn.Conv1d(192, 160, kernel_size=1) # (B, 192, F//2) -> (B, 160, F//2)
        # Initializing end_conv's weights to zero makes the layer a no-op at the start of training, which helps stabilize training.
        self.end_conv.weight.data.zero_() # initialize the weights to 0
        self.end_conv.bias.data.zero_() # initialize the bias to 0
def forward(self, x, x_mask, reverse=False):
"""
=====inputs=====
x: (B, 160, F//2)
x_mask: (B, 1, F//2)
=====outputs=====
z: (B, 160, F//2)
log_det: (B) or None
"""
B, C, f = x.size() # B, 160, F//2
x_0, x_1 = x[:, :C//2, :], x[:, C//2:, :] # split: (B, 80, F//2) x2
x = self.start_conv(x_0) * x_mask # (B, 80, F//2) -> (B, 192, F//2) & masking
x = self.wn(x, x_mask) # (B, 192, F//2)
out = self.end_conv(x) # (B, 192, F//2) -> (B, 160, F//2)
z_0 = x_0 # (B, 80, F//2)
m = out[:, :C//2, :] # (B, 80, F//2)
log_s = out[:, C//2:, :] # (B, 80, F//2)
if not reverse:
z_1 = (torch.exp(log_s) * x_1 + m) * x_mask # (B, 80, F//2) | function & masking
log_det = torch.sum(log_s * x_mask, [1, 2]) # (B)
else:
z_1 = (x_1 - m) / torch.exp(log_s) * x_mask # (B, 80, F//2) | inverse function & masking
log_det = None
z = torch.cat([z_0, z_1], dim=1) # (B, 160, F//2)
return z, log_det
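# Invertibility note: z_0 = x_0 passes through unchanged, so the reverse pass recomputes the same
# m and log_s from z_0 and recovers x_1 = (z_1 - m) / exp(log_s); only the x_1 half is transformed.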
def Unsqueeze(x, x_mask):
"""
Decoder์˜ postprocessing
=====inputs=====
x: (B, 160, F//2)
x_mask: (B, 1, F//2)
=====outputs=====
x: (B, 80, F)
x_mask: (B, 1, F)
"""
B, C, f = x.size() # B, 160, F//2
x = x.view(B, 2, C//2, f) # (B, 2, 80, F//2)
x = x.permute(0, 2, 3, 1).contiguous() # (B, 80, F//2, 2)
    x = x.view(B, C//2, 2*f) # (B, 80, F)
x_mask = x_mask.unsqueeze(3).repeat(1, 1, 1, 2).view(B, 1, 2*f) # (B, 1, F//2, 1) -> (B, 1, F//2, 2) -> (B, 1, F)
x = x * x_mask # masking
return x, x_mask
class Encoder(nn.Module):
def __init__(self):
super().__init__()
self.embedding = nn.Embedding(symbol_length, 192) # (B, T) -> (B, T, 192)
        nn.init.normal_(self.embedding.weight, 0.0, 192**(-0.5)) # initialize the weights from a normal distribution (N(0, 0.07xx))
self.prenet = PreNet()
self.transformer_encoder = TransformerEncoder()
self.project_mean = nn.Conv1d(192, 80, kernel_size=1) # (B, 192, T) -> (B, 80, T)
self.project_std = nn.Conv1d(192, 80, kernel_size=1) # (B, 192, T) -> (B, 80, T)
self.duration_predictor = DurationPredictor()
def forward(self, text, text_len):
"""
=====inputs=====
text: (B, Max_T)
text_len: (B)
=====outputs=====
        x_mean: (B, 80, T) | mean; the paper authors' train.py sets out_channels to 80.
        x_std: (B, 80, T) | standard deviation (returned in log scale; see mean_only below)
        x_dur: (B, 1, T) | predicted duration in log scale
x_mask: (B, 1, T)
"""
        x = self.embedding(text) * math.sqrt(192) # (B, T) -> (B, T, 192) | scale by math.sqrt(192) = 13.8xx
x = x.transpose(1, 2) # (B, T, 192) -> (B, 192, T)
# Make the x_mask
x_mask = torch.zeros_like(x[:, 0:1, :], dtype=torch.bool) # (B, 1, T)
for idx, length in enumerate(text_len):
x_mask[idx, :, :length] = True
x = self.prenet(x, x_mask) # (B, 192, T)
x = self.transformer_encoder(x, x_mask) # (B, 192, T)
# project
x_mean = self.project_mean(x) * x_mask # (B, 192, T) -> (B, 80, T)
        # x_std = self.project_std(x) * x_mask # (B, 192, T) -> (B, 80, T)
        ##### The line below applies the mean_only setting. #####
        x_std = torch.zeros_like(x_mean) # x_log_std: (B, 80, T), all zeros | log std = 0, so std is computed as 1.
# duration predictor
x_dp = torch.detach(x) # stop_gradient
x_dur = self.duration_predictor(x_dp, x_mask) # (B, 192, T) -> (B, 1, T)
return x_mean, x_std, x_dur, x_mask
class LayerNorm(nn.Module):
"""
    Module used in several places for normalization.
    nn.LayerNorm already exists in PyTorch, but it always normalizes the last dimension,
    so a LayerNorm that normalizes over the channel dimension is implemented separately here.
"""
def __init__(self, channels):
"""
        channels: number of channels of the input data | this LayerNorm normalizes the channel dimension.
"""
super().__init__()
self.channels = channels
self.eps = 1e-4
        self.gamma = nn.Parameter(torch.ones(channels)) # learnable parameter
        self.beta = nn.Parameter(torch.zeros(channels)) # learnable parameter
def forward(self, x):
"""
=====inputs=====
        x: (B, channels, *) | input data to normalize
=====outputs=====
        x: (B, channels, *) | data normalized over the channel dimension
"""
        mean = torch.mean(x, dim=1, keepdim=True) # mean over the channel dimension (index=1), keeping the dimension
        variance = torch.mean((x-mean)**2, dim=1, keepdim=True) # variance
x = (x - mean) * (variance + self.eps)**(-0.5) # (x - m) / sqrt(v)
n = len(x.shape)
shape = [1] * n
shape[1] = -1 # shape = [1, -1, 1] or [1, -1, 1, 1]
x = x * self.gamma.view(*shape) + self.beta.view(*shape) # y = x*gamma + beta
return x
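# Equivalence sketch for the LayerNorm above (a note, not used by the model): for x of shape
# (B, C, T), LayerNorm(C)(x) matches nn.LayerNorm(C)(x.transpose(1, 2)).transpose(1, 2),
# up to eps (1e-4 here versus PyTorch's default 1e-5).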
class PreNet(nn.Module):
"""
Encoder์˜ 1๋ฒˆ์งธ ๋ชจ๋“ˆ
"""
def __init__(self):
super().__init__()
self.convs = nn.ModuleList()
self.norms = nn.ModuleList()
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
for i in range(3):
            self.convs.append(nn.Conv1d(192, 192, kernel_size=5, padding=2)) # keeps (B, 192, T)
            self.norms.append(LayerNorm(192)) # keeps (B, 192, T)
        self.linear = nn.Conv1d(192, 192, kernel_size=1) # keeps (B, 192, T) | conv acting as a linear layer
def forward(self, x, x_mask):
"""
=====inputs=====
        x: (B, 192, T) | embedded input data
        x_mask: (B, 1, T) | mask by text length (True where a character is present, False elsewhere)
=====outputs=====
x: (B, 192, T)
"""
x0 = x
for i in range(3):
x = self.convs[i](x * x_mask)
x = self.norms[i](x)
x = self.relu(x)
x = self.dropout(x)
x = self.linear(x)
x = x0 + x # residual connection
return x
class MultiHeadAttention(nn.Module):
"""
    First module of TransformerEncoder, which is the second module of the Encoder
"""
def __init__(self):
super().__init__()
self.n_heads = 2
self.window_size = 4
self.k_channels = 192 // self.n_heads # 96
        self.linear_q = nn.Conv1d(192, 192, kernel_size=1) # keeps (B, 192, T)
        self.linear_k = nn.Conv1d(192, 192, kernel_size=1) # keeps (B, 192, T)
        self.linear_v = nn.Conv1d(192, 192, kernel_size=1) # keeps (B, 192, T)
nn.init.xavier_uniform_(self.linear_q.weight)
nn.init.xavier_uniform_(self.linear_k.weight)
nn.init.xavier_uniform_(self.linear_v.weight)
relative_std = self.k_channels ** (-0.5) # 0.1xx
self.relative_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.k_channels) * relative_std) # (1, 9, 96)
self.relative_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.k_channels) * relative_std) # (1, 9, 96)
self.attention_weights = None
        self.linear_out = nn.Conv1d(192, 192, kernel_size=1) # keeps (B, 192, T)
self.dropout = nn.Dropout(0.1)
def forward(self, query, context, attention_mask, self_attention=True):
"""
=====inputs=====
        query: (B, 192, T_target) | Glow-TTS only uses self-attention, so query and context are the same tensor x.
        context: (B, 192, T_source) | query = context; in particular, T_source = T_target here.
        attention_mask: (B, 1, T, T) | x_mask.unsqueeze(2) * x_mask.unsqueeze(3)
        self_attention: True/False | when True, relative position representations are applied; always True here.
        # In practice, simply pass the same tensor x as both query and context.
=====outputs=====
output: (B, 192, T)
"""
query = self.linear_q(query)
key = self.linear_k(context)
value = self.linear_v(context)
B, _, T_tar = query.size()
T_src = key.size(2)
query = query.view(B, self.n_heads, self.k_channels, T_tar).transpose(2, 3)
key = key.view(B, self.n_heads, self.k_channels, T_src).transpose(2, 3)
value = value.view(B, self.n_heads, self.k_channels, T_src).transpose(2, 3)
# (B, 192, T_src) -> (B, 2, 96, T_src) -> (B, 2, T_src, 96)
scores = torch.matmul(query, key.transpose(2, 3)) / (self.k_channels ** 0.5)
# (B, 2, T_tar, 96) * (B, 2, 96, T_src) -> (B, 2, T_tar, T_src)
if self_attention: # True
# Get relative embeddings (relative_keys) (1-1)
padding = max(T_src - (self.window_size + 1), 0) # max(T-5, 0)
start_pos = max((self.window_size + 1) - T_src, 0) # max(5-T, 0)
end_pos = start_pos + 2 * T_src - 1 # (2*T-1) or (T+4)
relative_keys = F.pad(self.relative_k, (0, 0, padding, padding))
# (1, 9, 96) -> (1, pad+9+pad, 96) = (1, 2T-1, 96)
"""
์œ„ ์ฝ”๋“œ์˜ F.pad(input, pad) ์—์„œ pad = (0, 0, padding, padding)์€ ๋‹ค์Œ์„ ์˜๋ฏธํ•œ๋‹ค.
- ์•ž์˜ (0, 0): input์˜ -1์ฐจ์›์„ ์•ž์œผ๋กœ 0, ๋’ค๋กœ 0๋งŒํผ ํŒจ๋”ฉํ•œ๋‹ค.
- ์•ž์˜ (padding, padding): input์˜ -2์ฐจ์›์„ ์•ž์œผ๋กœ padding, ๋’ค๋กœ padding๋งŒํผ ํŒจ๋”ฉํ•œ๋‹ค.
์ฆ‰, F.pad์—์„œ pad๋Š” ์—ญ์ˆœ์œผ๋กœ ์ƒ๊ฐํ•ด์ฃผ์–ด์•ผ ํ•œ๋‹ค.
"""
relative_keys = relative_keys[:, start_pos:end_pos, :] # (1, 2T-1, 96)
# Matmul with relative keys (2-1)
relative_keys = relative_keys.unsqueeze(0).transpose(2, 3) # (1, 2T-1, 96) -> (1, 1, 2T-1, 96) -> (1, 1, 96, 2T-1)
x = torch.matmul(query, relative_keys) # (B, 2, T_tar, 96) * (1, 1, 96, 2T_src-1) = (B, 2, T, 2T-1)
            # In self-attention T_tar = T_src, so the two need not be treated differently.
# Relative position to absolute position (3-1)
            T = T_tar # also used in "Absolute position to relative position" below
x = F.pad(x, (0, 1)) # (B, 2, T, 2*T-1) -> (B, 2, T, 2*T)
            x = x.view(B, self.n_heads, T * 2 * T) # (B, 2, T, 2*T) -> (B, 2, 2T^2)
x = F.pad(x, (0, T-1)) # (B, 2, 2T^2 + T - 1)
x = x.view(B, self.n_heads, T+1, 2*T-1) # (B, 2, T+1, 2T-1)
relative_logits = x[:, :, :T, T-1:] # (B, 2, T, T)
# Compute scores
scores_local = relative_logits / (self.k_channels ** 0.5)
scores = scores + scores_local # (B, 2, T, T)
"""
์œ„ ์‹์€ Self-Attention with Relative Position Representations ๋…ผ๋ฌธ์˜ 5๋ฒˆ ์‹์„ ๊ตฌํ˜„ํ•œ ๊ฒƒ์ด๋‹ค.
Relative- ๋…ผ๋ฌธ: https://arxiv.org/pdf/1803.02155.pdf
"""
        scores = scores.masked_fill(attention_mask == 0, -1e4) # fill positions where attention_mask is 0 with a large negative value so softmax ignores them
        attention_weights = F.softmax(scores, dim=-1) # (B, 2, T_tar, T_src) # corresponds to alpha in the Relative- paper
        attention_weights = self.dropout(attention_weights) # why apply dropout here?
output = torch.matmul(attention_weights, value) # (B, 2, T_tar, T_src) * (B, 2, T_src, 96) -> (B, 2, T_tar, 96)
if self_attention: # True
# Absolute position to relative position (3-2)
x = F.pad(attention_weights, (0, T-1)) # (B, 2, T, T) -> (B, 2, T, 2T-1)
x = x.view((B, self.n_heads, T * (2*T-1))) # (B, 2, 2T^2-T)
            x = F.pad(x, (T, 0)) # (B, 2, 2T^2) # pad at the front
x = x.view((B, self.n_heads, T, 2*T)) # (B, 2, T, 2T)
relative_weights = x[:, :, :, 1:] # (B, 2, T, 2T-1)
            # Get relative embeddings (relative_value) (1-2) # nearly identical to (1-1)
padding = max(T_src - (self.window_size + 1), 0) # max(T-5, 0)
start_pos = max((self.window_size + 1) - T_src, 0) # max(5-T, 0)
end_pos = start_pos + 2 * T_src - 1 # (2*T-1) or (T+4)
relative_values = F.pad(self.relative_v, (0, 0, padding, padding))
# (1, 9, 96) -> (1, pad+9+pad, 96) = (1, 2T-1, 96)
relative_values = relative_values[:, start_pos:end_pos, :] # (1, 2T-1, 96)
# Matmul with relative values (2-2)
relative_values = relative_values.unsqueeze(0) # (1, 1, 2T-1, 96)
output = output + torch.matmul(relative_weights, relative_values)
# (B, 2, T, 2T-1) * (1, 1, 2T-1, 96) = (B, 2, T, 96)
"""
์œ„ ์‹์€ Self-Attention with Relative Position Representations ๋…ผ๋ฌธ์˜ 3๋ฒˆ ์‹์„ ๊ตฌํ˜„ํ•œ ๊ฒƒ์ด๋‹ค. (๋ถ„๋ฐฐ๋ฒ•์น™ ์ด์šฉ)
Relative- ๋…ผ๋ฌธ: https://arxiv.org/pdf/1803.02155.pdf
"""
output = output.transpose(2, 3).contiguous().view(B, 192, T_tar)
        # (B, 2, 96, T) -> make contiguous in memory -> (B, 192, T)
self.attention_weights = attention_weights # (B, 2, T, T)
output = self.linear_out(output)
return output # (B, 192, T)
class FFN(nn.Module):
"""
    Second module of TransformerEncoder, which is the second module of the Encoder
"""
def __init__(self):
super().__init__()
self.conv1 = nn.Conv1d(192, 768, kernel_size=3, padding=1) # (B, 192, T) -> (B, 768, T)
self.relu = nn.ReLU()
self.conv2 = nn.Conv1d(768, 192, kernel_size=3, padding=1) # (B, 768, T) -> (B, 192, T)
self.dropout = nn.Dropout(0.1)
def forward(self, x, x_mask):
"""
=====inputs=====
x: (B, 192, T)
x_mask: (B, 1, T)
=====outputs=====
output: (B, 192, T)
"""
x = self.conv1(x)
x = self.relu(x)
x = self.dropout(x)
x = self.conv2(x)
output = x * x_mask
return output
class TransformerEncoder(nn.Module):
"""
Encoder์˜ 2๋ฒˆ์งธ ๋ชจ๋“ˆ
"""
def __init__(self):
super().__init__()
self.attentions = nn.ModuleList()
self.norms1 = nn.ModuleList()
self.ffns = nn.ModuleList()
self.norms2 = nn.ModuleList()
for i in range(6):
self.attentions.append(MultiHeadAttention())
self.norms1.append(LayerNorm(192))
self.ffns.append(FFN())
self.norms2.append(LayerNorm(192))
self.dropout = nn.Dropout(0.1)
def forward(self, x, x_mask):
"""
=====inputs=====
x: (B, 192, T)
x_mask: (B, 1, T)
=====outputs=====
output: (B, 192, T)
"""
attention_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(3)
        # (B, 1, 1, T) * (B, 1, T, 1) = (B, 1, T, T), consisting only of 0s and 1s
for i in range(6):
x = x * x_mask
y = self.attentions[i](x, x, attention_mask)
y = self.dropout(y)
x = x + y # residual connection
            x = self.norms1[i](x) # keeps (B, 192, T)
y = self.ffns[i](x, x_mask)
y = self.dropout(y)
x = x + y # residual connection
x = self.norms2[i](x)
output = x * x_mask
return output # (B, 192, T)
class DurationPredictor(nn.Module):
"""
Encoder์˜ 3๋ฒˆ์งธ ๋ชจ๋“ˆ
"""
def __init__(self):
super().__init__()
self.conv1 = nn.Conv1d(192, 256, kernel_size=3, padding=1) # (B, 192, T) -> (B, 256, T)
self.norm1 = LayerNorm(256)
self.conv2 = nn.Conv1d(256, 256, kernel_size=3, padding=1) # (B, 256, T) -> (B, 256, T)
self.norm2 = LayerNorm(256)
self.linear = nn.Conv1d(256, 1, kernel_size=1) # (B, 256, T) -> (B, 1, T)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
def forward(self, x, x_mask):
"""
=====inputs=====
x: (B, 192, T)
x_mask: (B, 1, T)
=====outputs=====
output: (B, 1, T)
"""
x = self.conv1(x * x_mask) # (B, 192, T) -> (B, 256, T)
x = self.relu(x)
x = self.norm1(x)
x = self.dropout(x)
x = self.conv2(x * x_mask) # (B, 256, T) -> (B, 256, T)
x = self.relu(x)
x = self.norm2(x)
x = self.dropout(x)
x = self.linear(x * x_mask) # (B, 256, T) -> (B, 1, T)
output = x * x_mask
return output