Spaces:
Running
Running
Create Tmodel.py
#26
by
marigold334
- opened
Tmodel.py
ADDED
@@ -0,0 +1,894 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.nn.utils import weight_norm
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
|
8 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
9 |
+
symbol_length = 73
|
10 |
+
|
11 |
+
class GlowTTS(nn.Module):
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
self.encoder = Encoder()
|
15 |
+
self.decoder = Decoder()
|
16 |
+
|
17 |
+
def forward(self, text, text_len, mel=None, mel_len=None, inference=False):
|
18 |
+
"""
|
19 |
+
=====inputs=====
|
20 |
+
text: (B, T)
|
21 |
+
text_len: (B) list
|
22 |
+
mel: (B, 80, F)
|
23 |
+
mel_len: (B) list
|
24 |
+
inference: True/False
|
25 |
+
=====outputs=====
|
26 |
+
(tuple) (z, z_mean, z_log_std, log_det, z_mask)
|
27 |
+
z(training) or y(inference): (B, 80, F) | z: latent representation, y: mel-spectrogram
|
28 |
+
z_mean: (B, 80, F)
|
29 |
+
z_log_std: (B, 80, F)
|
30 |
+
log_det: (B) or None
|
31 |
+
z_mask: (B, 1, F)
|
32 |
+
(tuple) (x_mean, x_log_std, x_mask)
|
33 |
+
x_mean: (B, 80, T)
|
34 |
+
x_log_std: (B, 80, T)
|
35 |
+
x_mask: (B, 1, T)
|
36 |
+
(tuple) (attention_alignment, x_log_dur, log_d)
|
37 |
+
attention_alignment: (B, T, F)
|
38 |
+
x_log_dur: (B, 1, T) | 추측한 duration의 log scale
|
39 |
+
log_d: (B, 1, T) | 적절하다고 추측한 alignment에서의 duration의 log scale
|
40 |
+
"""
|
41 |
+
x_mean, x_log_std, x_log_dur, x_mask = self.encoder(text, text_len)
|
42 |
+
# x_std, x_dur 에 log를 붙인 이유는, 논문 저자의 구현에서는 log가 취해진 값으로 간주하기 때문이다.
|
43 |
+
y, y_len = mel, mel_len
|
44 |
+
|
45 |
+
if not inference: # training
|
46 |
+
y_max_len = y.size(2)
|
47 |
+
else: # inference
|
48 |
+
dur = torch.exp(x_log_dur) * x_mask # (B, 1, T)
|
49 |
+
ceil_dur = torch.ceil(dur) # (B, 1, T)
|
50 |
+
y_len = torch.clamp_min(torch.sum(ceil_dur, [1, 2]), 1).long() # (B)
|
51 |
+
# ceil_dur을 [1, 2] 축에 대해 sum한 뒤 최솟값이 1이상이 되도록 설정. 정수 long 타입으로 반환한다.
|
52 |
+
y_max_len = None
|
53 |
+
|
54 |
+
# preprocessing
|
55 |
+
if y_max_len is not None:
|
56 |
+
y_max_len = (y_max_len // 2) * 2 # 홀수면 1을 빼서 짝수로 만든다.
|
57 |
+
y = y[:, :, :y_max_len] # y_max_len에 맞게 y를 조정
|
58 |
+
y_len = (y_len // 2) * 2 # y_len이 홀수이면 1을 빼서 짝수로 만든다.
|
59 |
+
|
60 |
+
# make the z_mask
|
61 |
+
B = len(y_len)
|
62 |
+
temp_max = max(y_len)
|
63 |
+
z_mask = torch.zeros((B, 1, temp_max), dtype=torch.bool).to(device) # (B, 1, F)
|
64 |
+
for idx, length in enumerate(y_len):
|
65 |
+
z_mask[idx, :, :length] = True
|
66 |
+
|
67 |
+
# make the attention_mask
|
68 |
+
attention_mask = x_mask.unsqueeze(3) * z_mask.unsqueeze(2) # (B, 1, T, 1) * (B, 1, 1, F) = (B, 1, T, F)
|
69 |
+
# 주의: Encoder의 attention_mask와는 다른 mask임.
|
70 |
+
|
71 |
+
if not inference: # training
|
72 |
+
z, log_det = self.decoder(y, z_mask, reverse=False)
|
73 |
+
with torch.no_grad():
|
74 |
+
x_std_squared_root = torch.exp(-2 * x_log_std) # (B, 80, T)
|
75 |
+
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_log_std, [1]).unsqueeze(-1) # [(B, T, F)
|
76 |
+
logp2 = torch.matmul(x_std_squared_root.transpose(1, 2), -0.5 * (z ** 2)) # [(B, T, 80) * (B, 80, F) = (B, T, F)
|
77 |
+
logp3 = torch.matmul((x_mean * x_std_squared_root).transpose(1,2), z) # (B, T, 80) * (B, 80, F) = (B, T, F)
|
78 |
+
logp4 = torch.sum(-0.5 * (x_mean ** 2) * x_std_squared_root, [1]).unsqueeze(-1) # (B, T, F)
|
79 |
+
logp = logp1 + logp2 + logp3 + logp4 # (B, T, F)
|
80 |
+
"""
|
81 |
+
logp는 normal distribution N(x_mean, x_std)의 maximum log-likelihood이다.
|
82 |
+
sum(log(N(z;x_mean, x_std)))를 정규분포 식을 이용하여 분배법칙으로 풀어내면 위와 같은 식이 도출된다.
|
83 |
+
"""
|
84 |
+
attention_alignment = maximum_path(logp, attention_mask.squeeze(1)).detach() # alignment (B, T, F)
|
85 |
+
|
86 |
+
z_mean = torch.matmul(attention_alignment.transpose(1, 2), x_mean.transpose(1, 2)) # (B, F, T) * (B, T, 80) -> (B, F, 80)
|
87 |
+
z_mean = z_mean.transpose(1, 2) # (B, 80, F)
|
88 |
+
z_log_std = torch.matmul(attention_alignment.transpose(1, 2), x_log_std.transpose(1, 2)) # (B, F, T) * (B, T, 80) -> (B, F, 80)
|
89 |
+
z_log_std = z_log_std.transpose(1, 2) # (B, 80, F)
|
90 |
+
log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask # (B, 1, T) | alignment에서 형성된 duration의 log scale
|
91 |
+
return (z, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
|
92 |
+
|
93 |
+
else: # inference
|
94 |
+
# generate_path (make attention_alignment using ceil(x_dur))
|
95 |
+
attention_alignment = generate_path(ceil_dur.squeeze(1), attention_mask.squeeze(1)) # (B, T, F)
|
96 |
+
z_mean = torch.matmul(attention_alignment.transpose(1, 2), x_mean.transpose(1, 2)) # (B, F, T) * (B, T, 80) -> (B, F, 80)
|
97 |
+
z_mean = z_mean.transpose(1, 2) # (B, 80, F)
|
98 |
+
z_log_std = torch.matmul(attention_alignment.transpose(1, 2), x_log_std.transpose(1, 2)) # (B, F, T) * (B, T, 80) -> (B, F, 80)
|
99 |
+
z_log_std = z_log_std.transpose(1, 2) # (B, 80, F)
|
100 |
+
log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask # (B, 1, T) | alignment에서 형성된 duration의 log scale
|
101 |
+
|
102 |
+
z = (z_mean + torch.exp(z_log_std) * torch.randn_like(z_mean)) * z_mask # z(latent representation) 생성
|
103 |
+
y, log_det = self.decoder(z, z_mask, reverse=True) # mel-spectrogram 생성
|
104 |
+
return (y, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
|
105 |
+
|
106 |
+
##### 아래 논문의 구현이 훨씬 빠르다. 이 논문 구현을 보고 위의 구현을 변경할 필요가 있다. #####
|
107 |
+
def maximum_path(value, mask, max_neg_val=-np.inf):
|
108 |
+
""" Numpy-friendly version. It's about 4 times faster than torch version.
|
109 |
+
value: [b, t_x, t_y]
|
110 |
+
mask: [b, t_x, t_y]
|
111 |
+
"""
|
112 |
+
value = value * mask
|
113 |
+
|
114 |
+
device = value.device
|
115 |
+
dtype = value.dtype
|
116 |
+
value = value.cpu().detach().numpy()
|
117 |
+
mask = mask.cpu().detach().numpy().astype(bool)
|
118 |
+
|
119 |
+
b, t_x, t_y = value.shape
|
120 |
+
direction = np.zeros(value.shape, dtype=np.int64)
|
121 |
+
v = np.zeros((b, t_x), dtype=np.float32)
|
122 |
+
x_range = np.arange(t_x, dtype=np.float32).reshape(1,-1)
|
123 |
+
for j in range(t_y):
|
124 |
+
v0 = np.pad(v, [[0,0],[1,0]], mode="constant", constant_values=max_neg_val)[:, :-1]
|
125 |
+
v1 = v
|
126 |
+
max_mask = (v1 >= v0)
|
127 |
+
v_max = np.where(max_mask, v1, v0)
|
128 |
+
direction[:, :, j] = max_mask
|
129 |
+
|
130 |
+
index_mask = (x_range <= j)
|
131 |
+
v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
|
132 |
+
direction = np.where(mask, direction, 1)
|
133 |
+
|
134 |
+
path = np.zeros(value.shape, dtype=np.float32)
|
135 |
+
index = mask[:, :, 0].sum(1).astype(np.int64) - 1
|
136 |
+
index_range = np.arange(b)
|
137 |
+
for j in reversed(range(t_y)):
|
138 |
+
path[index_range, index, j] = 1
|
139 |
+
index = index + direction[index_range, index, j] - 1
|
140 |
+
path = path * mask.astype(np.float32)
|
141 |
+
path = torch.from_numpy(path).to(device=device, dtype=dtype)
|
142 |
+
return path
|
143 |
+
|
144 |
+
|
145 |
+
def generate_path(duration, mask):
|
146 |
+
"""
|
147 |
+
duration: [b, t_x]
|
148 |
+
mask: [b, t_x, t_y]
|
149 |
+
"""
|
150 |
+
device = duration.device
|
151 |
+
|
152 |
+
b, t_x, t_y = mask.shape # (B, T, F)
|
153 |
+
cum_duration = torch.cumsum(duration, 1) # 누적합, (B, T)
|
154 |
+
path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) # (B, T, F)
|
155 |
+
|
156 |
+
cum_duration_flat = cum_duration.view(b * t_x) # (B*T)
|
157 |
+
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) # (B*T, F)
|
158 |
+
path = path.view(b, t_x, t_y) # (B, T, F)
|
159 |
+
path = path.to(torch.float32)
|
160 |
+
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:,:-1] # (B, T, F) # T의 차원 맨 앞을 -1한다.
|
161 |
+
path = path * mask
|
162 |
+
return path
|
163 |
+
|
164 |
+
def sequence_mask(length, max_length=None):
|
165 |
+
if max_length is None:
|
166 |
+
max_length = length.max()
|
167 |
+
x = torch.arange(max_length, dtype=length.dtype, device=length.device)
|
168 |
+
return x.unsqueeze(0) < length.unsqueeze(1)
|
169 |
+
|
170 |
+
def convert_pad_shape(pad_shape):
|
171 |
+
l = pad_shape[::-1] # [[0, 0], [p, p], [0, 0]]
|
172 |
+
pad_shape = [item for sublist in l for item in sublist] # [0, 0, p, p, 0, 0]
|
173 |
+
return pad_shape
|
174 |
+
|
175 |
+
def MAS(path, logp, T_max, F_max):
|
176 |
+
"""
|
177 |
+
Glow-TTS의 모듈인 maximum_path의 모듈
|
178 |
+
MAS 알고리즘을 수행하는 함수이다.
|
179 |
+
=====inputs=====
|
180 |
+
path: (T, F)
|
181 |
+
logp: (T, F)
|
182 |
+
T_max: (1)
|
183 |
+
F_max: (1)
|
184 |
+
=====outputs=====
|
185 |
+
path: (T, F) | 0과 1로 구성된 alignment
|
186 |
+
"""
|
187 |
+
neg_inf = -1e9 # negative infinity
|
188 |
+
# forward
|
189 |
+
for j in range(F_max):
|
190 |
+
for i in range(max(0, T_max + j - F_max), min(T_max, j + 1)): # 평행사변형을 생각하라.
|
191 |
+
# Q_i_j-1 (current)
|
192 |
+
if i == j:
|
193 |
+
Q_cur = neg_inf
|
194 |
+
else:
|
195 |
+
Q_cur = logp[i, j-1] # j=0이면 i도 0이므로 j-1을 사용해도 된다.
|
196 |
+
|
197 |
+
# Q_i-1_j-1 (previous)
|
198 |
+
if i==0:
|
199 |
+
if j==0:
|
200 |
+
Q_prev = 0. # i=0, j=0인 경우에는 logp 값만 반영해야 한다.
|
201 |
+
else:
|
202 |
+
Q_prev = neg_inf # i=0인 경우에는 Q_i-1_j-1을 반영하지 않아야 한다.
|
203 |
+
else:
|
204 |
+
Q_prev = logp[i-1, j-1]
|
205 |
+
|
206 |
+
# logp에 Q를 갱신한다.
|
207 |
+
logp[i, j] = max(Q_cur, Q_prev) + logp[i, j]
|
208 |
+
|
209 |
+
# backtracking
|
210 |
+
idx = T_max - 1
|
211 |
+
for j in range(F_max-1, -1, -1): # F_max-1부터 -1까지(-1 포함 없이 0까지) -1씩 감소
|
212 |
+
path[idx, j] = 1
|
213 |
+
if idx != 0:
|
214 |
+
if (logp[idx, j-1] < logp[idx-1, j-1]) or (idx == j):
|
215 |
+
idx -= 1
|
216 |
+
|
217 |
+
return path
|
218 |
+
|
219 |
+
|
220 |
+
def maximum_path(logp, attention_mask):
|
221 |
+
"""
|
222 |
+
Glow-TTS에 사용되�� 모듈
|
223 |
+
MAS를 사용하여 alignment를 찾아주는 역할을 한다.
|
224 |
+
논문 저자 구현에서는 cpython을 이용하여 병렬 처리를 구현한 듯 하나
|
225 |
+
여기에서는 python만을 이용하여 구현하였다.
|
226 |
+
=====inputs=====
|
227 |
+
logp: (B, T, F) | N(x_mean, x_std)의 log-likelihood
|
228 |
+
attention_mask: (B, T, F)
|
229 |
+
=====outputs=====
|
230 |
+
path: (B, T, F) | alignment
|
231 |
+
"""
|
232 |
+
B = logp.shape[0]
|
233 |
+
|
234 |
+
logp = logp * attention_mask
|
235 |
+
# 계산은 CPU에서 실행되도록 하기 위해 기존의 device를 저장하고 .cpu().numpy()를 한다.
|
236 |
+
logp_device = logp.device
|
237 |
+
logp_type = logp.dtype
|
238 |
+
logp = logp.data.cpu().numpy().astype(np.float32)
|
239 |
+
attention_mask = attention_mask.data.cpu().numpy()
|
240 |
+
|
241 |
+
path = np.zeros_like(logp).astype(np.int32) # (B, T, F)
|
242 |
+
T_max = attention_mask.sum(1)[:, 0].astype(np.int32) # (B)
|
243 |
+
F_max = attention_mask.sum(2)[:, 0].astype(np.int32) # (B)
|
244 |
+
|
245 |
+
# MAS 알고리즘
|
246 |
+
for idx in range(B):
|
247 |
+
path[idx] = MAS(path[idx], logp[idx], T_max[idx], F_max[idx]) # (T, F)
|
248 |
+
return torch.from_numpy(path).to(device=logp_device, dtype=logp_type)
|
249 |
+
|
250 |
+
def generate_path(ceil_dur, attention_mask):
|
251 |
+
"""
|
252 |
+
Glow-TTS에 사용되는 모듈
|
253 |
+
inference 과정에서 alignment를 만들어낸다.
|
254 |
+
=====input=====
|
255 |
+
ceil_dur: (B, T) | 추론한 duration에 ceil 연산한 것 | ex) [[2, 1, 2, 2, ...], [1, 2, 1, 3, ...], ...]
|
256 |
+
attention_mask: (B, T, F)
|
257 |
+
=====output=====
|
258 |
+
path: (B, T, F) | alignment
|
259 |
+
"""
|
260 |
+
B, T, Frame = attention_mask.shape
|
261 |
+
cum_dur = torch.cumsum(ceil_dur, 1)
|
262 |
+
cum_dur = cum_dur.to(torch.int32) # (B, T) | 누적합 | ex) [[2, 3, 5, 7, ...], [1, 3, 4, 7, ...], ...]
|
263 |
+
path = torch.zeros(B, T, Frame).to(ceil_dur.device) # (B, T, F) | all False(0)
|
264 |
+
|
265 |
+
# make the sequence_mask
|
266 |
+
for b, batch_cum_dur in enumerate(cum_dur):
|
267 |
+
for t, each_cum_dur in enumerate(batch_cum_dur):
|
268 |
+
path[b, t, :each_cum_dur] = torch.ones((1, 1, each_cum_dur)).to(ceil_dur.device)
|
269 |
+
# cum_dur로부터 True(1)를 path에 새겨넣는다.
|
270 |
+
path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1] # (B, T, F)
|
271 |
+
"""
|
272 |
+
ex) batch를 잠시 제외해두고 예시를 든다.
|
273 |
+
[[1, 1, 0, 0, 0, 0, 0], [[0, 0, 0, 0, 0, 0, 0], [[1, 1, 0, 0, 0, 0, 0],
|
274 |
+
[1, 1, 1, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0], = [0, 0, 1, 0, 0, 0, 0],
|
275 |
+
[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0],
|
276 |
+
[1, 1, 1, 1, 1, 1, 1]] [1, 1, 1, 1, 1, 0, 0]] [0, 0, 0, 0, 0, 1, 1]]
|
277 |
+
"""
|
278 |
+
path = path * attention_mask
|
279 |
+
return path
|
280 |
+
|
281 |
+
class Decoder(nn.Module):
|
282 |
+
def __init__(self):
|
283 |
+
super().__init__()
|
284 |
+
self.flows = nn.ModuleList()
|
285 |
+
for i in range(12):
|
286 |
+
self.flows.append(ActNorm())
|
287 |
+
self.flows.append(InvertibleConv())
|
288 |
+
self.flows.append(AffineCouplingLayer())
|
289 |
+
|
290 |
+
def forward(self, x, x_mask, reverse=False):
|
291 |
+
"""
|
292 |
+
=====inputs=====
|
293 |
+
x: (B, 80, F) | mel-spectrogram(Direct) OR latent representation(Reverse)
|
294 |
+
x_mask: (B, 1, F)
|
295 |
+
=====outputs=====
|
296 |
+
z: (B, 80, F) | latent representation(Direct) OR mel-spectrogram(Reverse)
|
297 |
+
total_log_det: (B) or None | log determinant
|
298 |
+
"""
|
299 |
+
if not reverse:
|
300 |
+
flows = self.flows
|
301 |
+
total_log_det = 0
|
302 |
+
else:
|
303 |
+
flows = reversed(self.flows)
|
304 |
+
total_log_det = None
|
305 |
+
|
306 |
+
x, x_mask = Squeeze(x, x_mask) # (B, 80, F) -> (B, 160, F//2) | (B, 1, F) -> (B, 1, F//2)
|
307 |
+
|
308 |
+
for f in flows:
|
309 |
+
if not reverse:
|
310 |
+
x, log_det = f(x, x_mask, reverse=reverse)
|
311 |
+
total_log_det += log_det
|
312 |
+
else:
|
313 |
+
x, _ = f(x, x_mask, reverse=reverse)
|
314 |
+
|
315 |
+
x, x_mask = Unsqueeze(x, x_mask) # (B, 160, F//2) -> (B, 80, F) | (B, 1, F//2) -> (B, 1, F)
|
316 |
+
|
317 |
+
return x, total_log_det
|
318 |
+
|
319 |
+
"""
|
320 |
+
Decoder는 Glow: Generative Flow with Invertible 1×1 Convolutions 논문의 기본 구조를 따라간다.
|
321 |
+
Glow 논문: https://arxiv.org/pdf/1807.03039.pdf
|
322 |
+
"""
|
323 |
+
def Squeeze(x, x_mask):
|
324 |
+
"""
|
325 |
+
Decoder의 preprocessing
|
326 |
+
=====inputs=====
|
327 |
+
x: (B, 80, F) | mel_spectrogram or latent representation
|
328 |
+
x_mask: (B, 1, F)
|
329 |
+
=====outputs=====
|
330 |
+
x: (B, 160, F//2) | F//2 = [F/2] ([]: 가우스 기호)
|
331 |
+
x_mask: (B, 160, F//2)
|
332 |
+
"""
|
333 |
+
B, C, F = x.size()
|
334 |
+
x = x[:, :, :(F//2)*2] # F가 홀수이면 맨 뒤 한 frame을 버림.
|
335 |
+
x = x.view(B, C, F//2, 2) # (B, 80, F//2, 2)
|
336 |
+
x = x.permute(0, 3, 1, 2).contiguous() # (B, 2, 80, F//2)
|
337 |
+
x = x.view(B, C*2, F//2) # (B, 160, F//2)
|
338 |
+
|
339 |
+
x_mask = x_mask[:, :, 1::2] # (B, 1, F//2) frame을 1부터 한칸씩 건너뛴다.
|
340 |
+
x = x * x_mask # masking
|
341 |
+
return x, x_mask
|
342 |
+
|
343 |
+
class ActNorm(nn.Module):
|
344 |
+
"""
|
345 |
+
Decoder의 1번째 모듈
|
346 |
+
"""
|
347 |
+
def __init__(self):
|
348 |
+
super().__init__()
|
349 |
+
self.log_s = nn.Parameter(torch.zeros(1, 160, 1)) # Glow 논문의 s에서 log를 취한 것이다. 즉, log[s]
|
350 |
+
self.bias = nn.Parameter(torch.zeros(1, 160, 1))
|
351 |
+
|
352 |
+
def forward(self, x, x_mask, reverse=False):
|
353 |
+
"""
|
354 |
+
=====inputs=====
|
355 |
+
x: (B, 160, F//2) | mel_spectrogram features
|
356 |
+
x_mask: (B, 1, F//2) | mel_spectrogram features의 mask. (Decoder의 Squeeze에서 변형됨.)
|
357 |
+
=====outputs=====
|
358 |
+
z: (B, 160, F//2)
|
359 |
+
log_det: (B) or None | log_determinant, reverse=True이면 None 반환
|
360 |
+
"""
|
361 |
+
x_len = torch.sum(x_mask, [1, 2]) # (B) | 1, 2차원의 값을 더한다. cf. [1, 2] 대신 [2]만 사용하면 shape가 (B, 1)이 된다.
|
362 |
+
|
363 |
+
if not reverse:
|
364 |
+
z = (x * torch.exp(self.log_s) + self.bias) * x_mask # function & masking
|
365 |
+
log_det = x_len * torch.sum(self.log_s) # log_determinant
|
366 |
+
# Glow 논문의 Table 1을 확인하라. log_s를 log[s]라 볼 수 있다.
|
367 |
+
# determinant 대신 log_determinant를 사용하는 이유는 det보다 작은 수치와 적은 계산량 때문으로 추측된다.
|
368 |
+
else:
|
369 |
+
z = ((x - self.bias) / torch.exp(self.log_s)) * x_mask # inverse function & masking
|
370 |
+
log_det = None
|
371 |
+
|
372 |
+
return z, log_det
|
373 |
+
|
374 |
+
class InvertibleConv(nn.Module):
|
375 |
+
"""
|
376 |
+
Decoder의 2번째 모듈
|
377 |
+
"""
|
378 |
+
def __init__(self):
|
379 |
+
super().__init__()
|
380 |
+
Q = torch.linalg.qr(torch.FloatTensor(4, 4).normal_())[0] # (4, 4)
|
381 |
+
"""
|
382 |
+
torch.FloatTensor(4, 4).normal_(): 정규분포 N(0, 1)에서 무작위로 추출한 4x4 matrix
|
383 |
+
Q, R = torch.linalg.qr(W): QR분해 | Q: 직교 행렬, R: upper traiangular 행렬 cf. det(Q) = 1 or -1
|
384 |
+
"""
|
385 |
+
if torch.det(Q) < 0:
|
386 |
+
Q[:, 0] = -1 * Q[:, 0] # 0번째 열의 부호를 바꿔서 det(Q) = -1로 만든다.
|
387 |
+
self.W = nn.Parameter(Q)
|
388 |
+
|
389 |
+
def forward(self, x, x_mask, reverse=False):
|
390 |
+
"""
|
391 |
+
=====inputs=====
|
392 |
+
x: (B, 160, F//2)
|
393 |
+
x_mask: (B, 1, F//2)
|
394 |
+
=====outputs=====
|
395 |
+
z: (B, 160, F//2)
|
396 |
+
log_det: (B) or None
|
397 |
+
"""
|
398 |
+
B, C, f = x.size() # B, 160, F//2
|
399 |
+
x_len = torch.sum(x_mask, [1, 2]) # (B)
|
400 |
+
|
401 |
+
# channel mixing
|
402 |
+
x = x.view(B, 2, C//4, 2, f) # (B, 2, 40, 2, F//2)
|
403 |
+
x = x.permute(0, 1, 3, 2, 4).contiguous() # (B, 2, 2, 40, F//2)
|
404 |
+
x = x.view(B, 4, C//4, f) # (B, 4, 40, F//2)
|
405 |
+
|
406 |
+
# 편의상 log_det부터 구한다.
|
407 |
+
if not reverse:
|
408 |
+
weight = self.W
|
409 |
+
log_det = (C/4) * x_len * torch.logdet(self.W) # (B) | torch.logdet(W): log(det(W))
|
410 |
+
# height = C/4, width = x_len 인 상황임을 고려하면 Glow 논문의 log_determinant 식과 같다.
|
411 |
+
else:
|
412 |
+
weight = torch.linalg.inv(self.W) # inverse matrix
|
413 |
+
log_det = None
|
414 |
+
|
415 |
+
weight = weight.view(4, 4, 1, 1)
|
416 |
+
z = F.conv2d(x, weight) # (B, 4, 40, F//2) * (4, 4, 1, 1) -> (B, 4, 40, F//2)
|
417 |
+
"""
|
418 |
+
F.conv2d(x, weight)의 convolution 연산은 다음과 같이 생각해야 한다.
|
419 |
+
(B, 4, 40, F//2): (batch_size, in_channels, height, width)
|
420 |
+
(4, 4, 1, 1): (out_channels, in_channels/groups, kernel_height, kernel_width)
|
421 |
+
|
422 |
+
즉, nn.Conv2d(4, 4, kernel_size=(1, 1))인 상황에 가중치를 준 것이다.
|
423 |
+
"""
|
424 |
+
|
425 |
+
# channel unmixing
|
426 |
+
z = z.view(B, 2, 2, C//4, f) # (B, 4, 40, F//2) -> (B, 2, 2, 40, F//2)
|
427 |
+
z = z.permute(0, 1, 3, 2, 4).contiguous() # (B, 2, 40, 2, F//2)
|
428 |
+
z = z.view(B, C, f) * x_mask # (B, 160, F//2) & masking
|
429 |
+
return z, log_det
|
430 |
+
|
431 |
+
class WN(nn.Module):
|
432 |
+
"""
|
433 |
+
Decoder의 3번째 모듈인 AffineCouplingLayer의 모듈
|
434 |
+
|
435 |
+
해당 구조는 WAVEGLOW: A FLOW-BASED GENERATIVE NETWORK FOR SPEECH SYNTHESIS 로부터 제안되었다.
|
436 |
+
WaveGlow 논문: https://arxiv.org/pdf/1811.00002.pdf
|
437 |
+
"""
|
438 |
+
def __init__(self, dilation_rate=1):
|
439 |
+
super().__init__()
|
440 |
+
self.in_layers = nn.ModuleList()
|
441 |
+
self.res_skip_layers = nn.ModuleList()
|
442 |
+
|
443 |
+
for i in range(4):
|
444 |
+
dilation = dilation_rate ** i # NVIDIA WaveGlow에서는 dilation_rate=2이지만, 여기에서는 1이므로 의미는 없다.
|
445 |
+
in_layer = weight_norm(nn.Conv1d(192, 2*192, kernel_size=5, dilation=dilation,
|
446 |
+
padding=((5-1) * dilation)//2)) # (B, 192, F//2) -> (B, 2*192, F//2)
|
447 |
+
self.in_layers.append(in_layer)
|
448 |
+
|
449 |
+
if i < 3:
|
450 |
+
res_skip_layer = weight_norm(nn.Conv1d(192, 2*192, kernel_size=1)) # (B, 192, F//2) -> (B, 2*192, F//2)
|
451 |
+
else:
|
452 |
+
res_skip_layer = weight_norm(nn.Conv1d(192, 192, kernel_size=1)) # (B, 192, F//2) -> (B, 192, F//2)
|
453 |
+
self.res_skip_layers.append(res_skip_layer)
|
454 |
+
|
455 |
+
self.dropout = nn.Dropout(0.05)
|
456 |
+
|
457 |
+
def forward(self, x, x_mask):
|
458 |
+
"""
|
459 |
+
=====inputs=====
|
460 |
+
x: (B, 192, F//2)
|
461 |
+
x_mask: (B, 1, F//2)
|
462 |
+
=====outputs=====
|
463 |
+
output: (B, 192, F//2)
|
464 |
+
"""
|
465 |
+
output = torch.zeros_like(x) # (B, 192, F//2) all zeros
|
466 |
+
|
467 |
+
for i in range(4):
|
468 |
+
x_in = self.in_layers[i](x) # (B, 192, F//2) -> (B, 2*192, F//2)
|
469 |
+
x_in = self.dropout(x_in) # dropout
|
470 |
+
|
471 |
+
# fused add tanh sigmoid multiply
|
472 |
+
tanh_act = torch.tanh(x_in[:, :192, :]) # (B, 192, F//2)
|
473 |
+
sigmoid_act = torch.sigmoid(x_in[:, 192:, :]) # (B, 192, F//2)
|
474 |
+
|
475 |
+
acts = sigmoid_act * tanh_act # (B, 192, F//2)
|
476 |
+
|
477 |
+
x_out = self.res_skip_layers[i](acts) # (B, 192, F//2) -> (B, 2*192, F//2) or [last](B, 192, F//2)
|
478 |
+
if i < 3:
|
479 |
+
x = (x + x_out[:, :192, :]) * x_mask # residual connection & masking
|
480 |
+
output += x_out[:, 192:, :] # add output
|
481 |
+
else:
|
482 |
+
output += x_out # (B, 192, F//2)
|
483 |
+
|
484 |
+
output = output * x_mask # masking
|
485 |
+
return output
|
486 |
+
|
487 |
+
class AffineCouplingLayer(nn.Module):
|
488 |
+
"""
|
489 |
+
Decoder의 3번째 모듈
|
490 |
+
"""
|
491 |
+
def __init__(self):
|
492 |
+
super().__init__()
|
493 |
+
self.start_conv = weight_norm(nn.Conv1d(160//2, 192, kernel_size=1)) # (B, 80, F//2) -> (B, 192, F//2)
|
494 |
+
self.wn = WN()
|
495 |
+
self.end_conv = nn.Conv1d(192, 160, kernel_size=1) # (B, 192, F//2) -> (B, 160, F//2)
|
496 |
+
# end_conv의 초기 가중치를 0으로 설정하는 것이 처음에 학습하지 않는 역할을 하며, 이는 학습 안정화에 도움이 된다.
|
497 |
+
self.end_conv.weight.data.zero_() # weight를 0으로 초기화
|
498 |
+
self.end_conv.bias.data.zero_() # bias를 0으로 초기화
|
499 |
+
|
500 |
+
def forward(self, x, x_mask, reverse=False):
|
501 |
+
"""
|
502 |
+
=====inputs=====
|
503 |
+
x: (B, 160, F//2)
|
504 |
+
x_mask: (B, 1, F//2)
|
505 |
+
=====outputs=====
|
506 |
+
z: (B, 160, F//2)
|
507 |
+
log_det: (B) or None
|
508 |
+
"""
|
509 |
+
B, C, f = x.size() # B, 160, F//2
|
510 |
+
x_0, x_1 = x[:, :C//2, :], x[:, C//2:, :] # split: (B, 80, F//2) x2
|
511 |
+
|
512 |
+
x = self.start_conv(x_0) * x_mask # (B, 80, F//2) -> (B, 192, F//2) & masking
|
513 |
+
x = self.wn(x, x_mask) # (B, 192, F//2)
|
514 |
+
out = self.end_conv(x) # (B, 192, F//2) -> (B, 160, F//2)
|
515 |
+
|
516 |
+
z_0 = x_0 # (B, 80, F//2)
|
517 |
+
m = out[:, :C//2, :] # (B, 80, F//2)
|
518 |
+
log_s = out[:, C//2:, :] # (B, 80, F//2)
|
519 |
+
|
520 |
+
if not reverse:
|
521 |
+
z_1 = (torch.exp(log_s) * x_1 + m) * x_mask # (B, 80, F//2) | function & masking
|
522 |
+
log_det = torch.sum(log_s * x_mask, [1, 2]) # (B)
|
523 |
+
else:
|
524 |
+
z_1 = (x_1 - m) / torch.exp(log_s) * x_mask # (B, 80, F//2) | inverse function & masking
|
525 |
+
log_det = None
|
526 |
+
|
527 |
+
z = torch.cat([z_0, z_1], dim=1) # (B, 160, F//2)
|
528 |
+
return z, log_det
|
529 |
+
|
530 |
+
def Unsqueeze(x, x_mask):
|
531 |
+
"""
|
532 |
+
Decoder의 postprocessing
|
533 |
+
=====inputs=====
|
534 |
+
x: (B, 160, F//2)
|
535 |
+
x_mask: (B, 1, F//2)
|
536 |
+
=====outputs=====
|
537 |
+
x: (B, 80, F)
|
538 |
+
x_mask: (B, 1, F)
|
539 |
+
"""
|
540 |
+
B, C, f = x.size() # B, 160, F//2
|
541 |
+
x = x.view(B, 2, C//2, f) # (B, 2, 80, F//2)
|
542 |
+
x = x.permute(0, 2, 3, 1).contiguous() # (B, 80, F//2, 2)
|
543 |
+
x = x.view(B, C//2, 2*f) # (B, 160, F)
|
544 |
+
|
545 |
+
x_mask = x_mask.unsqueeze(3).repeat(1, 1, 1, 2).view(B, 1, 2*f) # (B, 1, F//2, 1) -> (B, 1, F//2, 2) -> (B, 1, F)
|
546 |
+
x = x * x_mask # masking
|
547 |
+
return x, x_mask
|
548 |
+
|
549 |
+
class Encoder(nn.Module):
|
550 |
+
def __init__(self):
|
551 |
+
super().__init__()
|
552 |
+
self.embedding = nn.Embedding(symbol_length, 192) # (B, T) -> (B, T, 192)
|
553 |
+
nn.init.normal_(self.embedding.weight, 0.0, 192**(-0.5)) # 가중치 정규분포 초기화 (N(0, 0.07xx))
|
554 |
+
|
555 |
+
self.prenet = PreNet()
|
556 |
+
self.transformer_encoder = TransformerEncoder()
|
557 |
+
self.project_mean = nn.Conv1d(192, 80, kernel_size=1) # (B, 192, T) -> (B, 80, T)
|
558 |
+
self.project_std = nn.Conv1d(192, 80, kernel_size=1) # (B, 192, T) -> (B, 80, T)
|
559 |
+
|
560 |
+
self.duration_predictor = DurationPredictor()
|
561 |
+
|
562 |
+
def forward(self, text, text_len):
|
563 |
+
"""
|
564 |
+
=====inputs=====
|
565 |
+
text: (B, Max_T)
|
566 |
+
text_len: (B)
|
567 |
+
=====outputs=====
|
568 |
+
x_mean: (B, 80, T) | 평균, 논문 저자 구현의 train.py에서 out_channels를 80으로 설정한 것을 알 수 있음.
|
569 |
+
x_std: (B, 80, T) | 표준편차
|
570 |
+
x_dur: (B, 1, T)
|
571 |
+
x_mask: (B, 1, T)
|
572 |
+
"""
|
573 |
+
x = self.embedding(text) * math.sqrt(192) # (B, T) -> (B, T, 192) # math.sqrt(192) = 13.xx (수정)
|
574 |
+
x = x.transpose(1, 2) # (B, T, 192) -> (B, 192, T)
|
575 |
+
|
576 |
+
# Make the x_mask
|
577 |
+
x_mask = torch.zeros_like(x[:, 0:1, :], dtype=torch.bool) # (B, 1, T)
|
578 |
+
for idx, length in enumerate(text_len):
|
579 |
+
x_mask[idx, :, :length] = True
|
580 |
+
|
581 |
+
x = self.prenet(x, x_mask) # (B, 192, T)
|
582 |
+
x = self.transformer_encoder(x, x_mask) # (B, 192, T)
|
583 |
+
|
584 |
+
# project
|
585 |
+
x_mean = self.project_mean(x) * x_mask # (B, 192, T) -> (B, 80, T)
|
586 |
+
# x_std = self.project_std(x) * x_mask # (B, 192, T) -> (B, 80, T)
|
587 |
+
##### 아래는 mean_only를 적용한 것임. #####
|
588 |
+
x_std = torch.zeros_like(x_mean) # x_log_std: (B, 80, T), all zero # log std = 0이므로 std = 1로 계산됨.
|
589 |
+
|
590 |
+
# duration predictor
|
591 |
+
x_dp = torch.detach(x) # stop_gradient
|
592 |
+
x_dur = self.duration_predictor(x_dp, x_mask) # (B, 192, T) -> (B, 1, T)
|
593 |
+
|
594 |
+
return x_mean, x_std, x_dur, x_mask
|
595 |
+
|
596 |
+
class LayerNorm(nn.Module):
|
597 |
+
"""
|
598 |
+
여러 곳에서 정규화(Norm)를 위해 사용되는 모듈.
|
599 |
+
|
600 |
+
nn.LayerNorm이 이미 pytorch 안에 구현되어 있으나, 항상 마지막 차원을 정규화한다.
|
601 |
+
그래서 channel을 기준으로 정규화하는 LayerNorm을 따로 구현한다.
|
602 |
+
"""
|
603 |
+
def __init__(self, channels):
|
604 |
+
"""
|
605 |
+
channels: 입력 데이터의 channel 수 | LayerNorm은 channel 차원을 정규화한다.
|
606 |
+
"""
|
607 |
+
super().__init__()
|
608 |
+
self.channels = channels
|
609 |
+
self.eps = 1e-4
|
610 |
+
|
611 |
+
self.gamma = nn.Parameter(torch.ones(channels)) # 학습 가능한 파라미터
|
612 |
+
self.beta = nn.Parameter(torch.zeros(channels)) # 학습 가능한 파라미터
|
613 |
+
|
614 |
+
def forward(self, x):
|
615 |
+
"""
|
616 |
+
=====inputs=====
|
617 |
+
x: (B, channels, *) | 정규화할 입력 데이터
|
618 |
+
=====outputs=====
|
619 |
+
x: (B, channels, *) | channel 차원이 정규화된 데이터
|
620 |
+
"""
|
621 |
+
mean = torch.mean(x, dim=1, keepdim=True) # channel 차원(index=1)의 평균 계산, 차원을 유지한다.
|
622 |
+
variance = torch.mean((x-mean)**2, dim=1, keepdim=True) # 분산 계산
|
623 |
+
|
624 |
+
x = (x - mean) * (variance + self.eps)**(-0.5) # (x - m) / sqrt(v)
|
625 |
+
|
626 |
+
n = len(x.shape)
|
627 |
+
shape = [1] * n
|
628 |
+
shape[1] = -1 # shape = [1, -1, 1] or [1, -1, 1, 1]
|
629 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape) # y = x*gamma + beta
|
630 |
+
|
631 |
+
return x
|
632 |
+
|
633 |
+
class PreNet(nn.Module):
|
634 |
+
"""
|
635 |
+
Encoder의 1번째 모듈
|
636 |
+
"""
|
637 |
+
def __init__(self):
|
638 |
+
super().__init__()
|
639 |
+
self.convs = nn.ModuleList()
|
640 |
+
self.norms = nn.ModuleList()
|
641 |
+
self.relu = nn.ReLU()
|
642 |
+
self.dropout = nn.Dropout(0.5)
|
643 |
+
for i in range(3):
|
644 |
+
self.convs.append(nn.Conv1d(192, 192, kernel_size=5, padding=2)) # (B, 192, T) 유지
|
645 |
+
self.norms.append(LayerNorm(192)) # (B, 192, T) 유지
|
646 |
+
self.linear = nn.Conv1d(192, 192, kernel_size=1) # (B, 192, T) 유지 | linear 역할을 하는 conv
|
647 |
+
|
648 |
+
def forward(self, x, x_mask):
|
649 |
+
"""
|
650 |
+
=====inputs=====
|
651 |
+
x: (B, 192, T) | Embedding된 입력 데이터
|
652 |
+
x_mask: (B, 1, T) | 글자 길이에 따른 mask (글자가 있으면 True, 없으면 False로 구성)
|
653 |
+
=====outputs=====
|
654 |
+
x: (B, 192, T)
|
655 |
+
"""
|
656 |
+
x0 = x
|
657 |
+
for i in range(3):
|
658 |
+
x = self.convs[i](x * x_mask)
|
659 |
+
x = self.norms[i](x)
|
660 |
+
x = self.relu(x)
|
661 |
+
x = self.dropout(x)
|
662 |
+
x = self.linear(x)
|
663 |
+
x = x0 + x # residual connection
|
664 |
+
return x
|
665 |
+
|
666 |
+
class MultiHeadAttention(nn.Module):
|
667 |
+
"""
|
668 |
+
Encoder 중 2번째 모듈인 TransformerEncoder의 1번째 모듈
|
669 |
+
"""
|
670 |
+
def __init__(self):
|
671 |
+
super().__init__()
|
672 |
+
self.n_heads = 2
|
673 |
+
self.window_size = 4
|
674 |
+
self.k_channels = 192 // self.n_heads # 96
|
675 |
+
|
676 |
+
self.linear_q = nn.Conv1d(192, 192, kernel_size=1) # (B, 192, T) 유지
|
677 |
+
self.linear_k = nn.Conv1d(192, 192, kernel_size=1) # (B, 192, T) 유지
|
678 |
+
self.linear_v = nn.Conv1d(192, 192, kernel_size=1) # (B, 192, T) 유지
|
679 |
+
nn.init.xavier_uniform_(self.linear_q.weight)
|
680 |
+
nn.init.xavier_uniform_(self.linear_k.weight)
|
681 |
+
nn.init.xavier_uniform_(self.linear_v.weight)
|
682 |
+
|
683 |
+
relative_std = self.k_channels ** (-0.5) # 0.1xx
|
684 |
+
self.relative_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.k_channels) * relative_std) # (1, 9, 96)
|
685 |
+
self.relative_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.k_channels) * relative_std) # (1, 9, 96)
|
686 |
+
|
687 |
+
self.attention_weights = None
|
688 |
+
self.linear_out = nn.Conv1d(192, 192, kernel_size=1) # (B, 192, T) 유지
|
689 |
+
self.dropout = nn.Dropout(0.1)
|
690 |
+
|
691 |
+
def forward(self, query, context, attention_mask, self_attention=True):
|
692 |
+
"""
|
693 |
+
=====inputs=====
|
694 |
+
query: (B, 192, T_target) | Glow-TTS에서는 self-attention만 이용하므로 query와 context가 동일한 텐서 x이다.
|
695 |
+
context: (B, 192, T_source) | query = context || 여기에서는 특히 T_source = T_target 이다.
|
696 |
+
attention_mask: (B, 1, T, T) | x_mask.unsqueeze(2) * z_mask.unsqueeze(3)
|
697 |
+
self_attention: True/False | self_attention일 때 relative position representations를 적용한다. 여기에서는 항상 True이다.
|
698 |
+
# 실제로는 query와 context에 같은 텐서 x를 입력하면 된다.
|
699 |
+
=====outputs=====
|
700 |
+
output: (B, 192, T)
|
701 |
+
"""
|
702 |
+
|
703 |
+
query = self.linear_q(query)
|
704 |
+
key = self.linear_k(context)
|
705 |
+
value = self.linear_v(context)
|
706 |
+
|
707 |
+
B, _, T_tar = query.size()
|
708 |
+
T_src = key.size(2)
|
709 |
+
query = query.view(B, self.n_heads, self.k_channels, T_tar).transpose(2, 3)
|
710 |
+
key = key.view(B, self.n_heads, self.k_channels, T_src).transpose(2, 3)
|
711 |
+
value = value.view(B, self.n_heads, self.k_channels, T_src).transpose(2, 3)
|
712 |
+
# (B, 192, T_src) -> (B, 2, 96, T_src) -> (B, 2, T_src, 96)
|
713 |
+
|
714 |
+
scores = torch.matmul(query, key.transpose(2, 3)) / (self.k_channels ** 0.5)
|
715 |
+
# (B, 2, T_tar, 96) * (B, 2, 96, T_src) -> (B, 2, T_tar, T_src)
|
716 |
+
|
717 |
+
if self_attention: # True
|
718 |
+
# Get relative embeddings (relative_keys) (1-1)
|
719 |
+
padding = max(T_src - (self.window_size + 1), 0) # max(T-5, 0)
|
720 |
+
start_pos = max((self.window_size + 1) - T_src, 0) # max(5-T, 0)
|
721 |
+
end_pos = start_pos + 2 * T_src - 1 # (2*T-1) or (T+4)
|
722 |
+
relative_keys = F.pad(self.relative_k, (0, 0, padding, padding))
|
723 |
+
# (1, 9, 96) -> (1, pad+9+pad, 96) = (1, 2T-1, 96)
|
724 |
+
"""
|
725 |
+
위 코드의 F.pad(input, pad) 에서 pad = (0, 0, padding, padding)은 다음을 의미한다.
|
726 |
+
- 앞의 (0, 0): input의 -1차원을 앞으로 0, 뒤로 0만큼 패딩한다.
|
727 |
+
- 앞의 (padding, padding): input의 -2차원을 앞으로 padding, 뒤로 padding만큼 패딩한다.
|
728 |
+
즉, F.pad에서 pad는 역순으로 생각해주어야 한다.
|
729 |
+
"""
|
730 |
+
relative_keys = relative_keys[:, start_pos:end_pos, :] # (1, 2T-1, 96)
|
731 |
+
|
732 |
+
# Matmul with relative keys (2-1)
|
733 |
+
relative_keys = relative_keys.unsqueeze(0).transpose(2, 3) # (1, 2T-1, 96) -> (1, 1, 2T-1, 96) -> (1, 1, 96, 2T-1)
|
734 |
+
x = torch.matmul(query, relative_keys) # (B, 2, T_tar, 96) * (1, 1, 96, 2T_src-1) = (B, 2, T, 2T-1)
|
735 |
+
# self attention에서는 T_tar = T_src이므로 이를 다르게 고려할 필요가 없다.
|
736 |
+
|
737 |
+
# Relative position to absolute position (3-1)
|
738 |
+
T = T_tar # Absolute position to relative position에서도 쓰임.
|
739 |
+
x = F.pad(x, (0, 1)) # (B, 2, T, 2*T-1) -> (B, 2, T, 2*T)
|
740 |
+
x = x.view(B, self.n_heads, T * 2 * T) # (B, 2, T, 2*T) -> (B, 2. 2T^2)
|
741 |
+
x = F.pad(x, (0, T-1)) # (B, 2, 2T^2 + T - 1)
|
742 |
+
x = x.view(B, self.n_heads, T+1, 2*T-1) # (B, 2, T+1, 2T-1)
|
743 |
+
relative_logits = x[:, :, :T, T-1:] # (B, 2, T, T)
|
744 |
+
|
745 |
+
# Compute scores
|
746 |
+
scores_local = relative_logits / (self.k_channels ** 0.5)
|
747 |
+
scores = scores + scores_local # (B, 2, T, T)
|
748 |
+
"""
|
749 |
+
위 식은 Self-Attention with Relative Position Representations 논문의 5번 식을 구현한 것이다.
|
750 |
+
Relative- 논문: https://arxiv.org/pdf/1803.02155.pdf
|
751 |
+
"""
|
752 |
+
|
753 |
+
scores = scores.masked_fill(attention_mask == 0, -1e-4) # attention_mask가 0인 곳을 -1e-4로 채운다.
|
754 |
+
|
755 |
+
attention_weights = F.softmax(scores, dim=-1) # (B, 2, T_tar, T_src) # Relative- 논문에서의 alpha에 해당한다.
|
756 |
+
attention_weights = self.dropout(attention_weights) # dropout하는 이유가 무엇일까?
|
757 |
+
output = torch.matmul(attention_weights, value) # (B, 2, T_tar, T_src) * (B, 2, T_src, 96) -> (B, 2, T_tar, 96)
|
758 |
+
|
759 |
+
if self_attention: # True
|
760 |
+
# Absolute position to relative position (3-2)
|
761 |
+
x = F.pad(attention_weights, (0, T-1)) # (B, 2, T, T) -> (B, 2, T, 2T-1)
|
762 |
+
x = x.view((B, self.n_heads, T * (2*T-1))) # (B, 2, 2T^2-T)
|
763 |
+
x = F.pad(x, (T, 0)) # (B, 2, 2T^2) # 앞에 패딩
|
764 |
+
x = x.view((B, self.n_heads, T, 2*T)) # (B, 2, T, 2T)
|
765 |
+
relative_weights = x[:, :, :, 1:] # (B, 2, T, 2T-1)
|
766 |
+
|
767 |
+
# Get relative embeddings (relative_value) (1-2) # (1-1)과 거의 동일
|
768 |
+
padding = max(T_src - (self.window_size + 1), 0) # max(T-5, 0)
|
769 |
+
start_pos = max((self.window_size + 1) - T_src, 0) # max(5-T, 0)
|
770 |
+
end_pos = start_pos + 2 * T_src - 1 # (2*T-1) or (T+4)
|
771 |
+
relative_values = F.pad(self.relative_v, (0, 0, padding, padding))
|
772 |
+
# (1, 9, 96) -> (1, pad+9+pad, 96) = (1, 2T-1, 96)
|
773 |
+
relative_values = relative_values[:, start_pos:end_pos, :] # (1, 2T-1, 96)
|
774 |
+
|
775 |
+
# Matmul with relative values (2-2)
|
776 |
+
relative_values = relative_values.unsqueeze(0) # (1, 1, 2T-1, 96)
|
777 |
+
|
778 |
+
output = output + torch.matmul(relative_weights, relative_values)
|
779 |
+
# (B, 2, T, 2T-1) * (1, 1, 2T-1, 96) = (B, 2, T, 96)
|
780 |
+
"""
|
781 |
+
위 식은 Self-Attention with Relative Position Representations 논문의 3번 식을 구현한 것이다. (분배법칙 이용)
|
782 |
+
Relative- 논문: https://arxiv.org/pdf/1803.02155.pdf
|
783 |
+
"""
|
784 |
+
|
785 |
+
output = output.transpose(2, 3).contiguous().view(B, 192, T_tar)
|
786 |
+
# (B, 2, 96, T) -> 메모리에 연속 배치 -> (B, 192, T)
|
787 |
+
|
788 |
+
self.attention_weights = attention_weights # (B, 2, T, T)
|
789 |
+
output = self.linear_out(output)
|
790 |
+
return output # (B, 192, T)
|
791 |
+
|
792 |
+
class FFN(nn.Module):
|
793 |
+
"""
|
794 |
+
Encoder 중 2번째 모듈인 TransformerEncoder의 2번째 모듈
|
795 |
+
"""
|
796 |
+
def __init__(self):
|
797 |
+
super().__init__()
|
798 |
+
self.conv1 = nn.Conv1d(192, 768, kernel_size=3, padding=1) # (B, 192, T) -> (B, 768, T)
|
799 |
+
self.relu = nn.ReLU()
|
800 |
+
self.conv2 = nn.Conv1d(768, 192, kernel_size=3, padding=1) # (B, 768, T) -> (B, 192, T)
|
801 |
+
self.dropout = nn.Dropout(0.1)
|
802 |
+
|
803 |
+
def forward(self, x, x_mask):
|
804 |
+
"""
|
805 |
+
=====inputs=====
|
806 |
+
x: (B, 192, T)
|
807 |
+
x_mask: (B, 1, T)
|
808 |
+
=====outputs=====
|
809 |
+
output: (B, 192, T)
|
810 |
+
"""
|
811 |
+
x = self.conv1(x)
|
812 |
+
x = self.relu(x)
|
813 |
+
x = self.dropout(x)
|
814 |
+
x = self.conv2(x)
|
815 |
+
output = x * x_mask
|
816 |
+
return output
|
817 |
+
|
818 |
+
class TransformerEncoder(nn.Module):
|
819 |
+
"""
|
820 |
+
Encoder의 2번째 모듈
|
821 |
+
"""
|
822 |
+
def __init__(self):
|
823 |
+
super().__init__()
|
824 |
+
self.attentions = nn.ModuleList()
|
825 |
+
self.norms1 = nn.ModuleList()
|
826 |
+
self.ffns = nn.ModuleList()
|
827 |
+
self.norms2 = nn.ModuleList()
|
828 |
+
for i in range(6):
|
829 |
+
self.attentions.append(MultiHeadAttention())
|
830 |
+
self.norms1.append(LayerNorm(192))
|
831 |
+
self.ffns.append(FFN())
|
832 |
+
self.norms2.append(LayerNorm(192))
|
833 |
+
self.dropout = nn.Dropout(0.1)
|
834 |
+
|
835 |
+
def forward(self, x, x_mask):
|
836 |
+
"""
|
837 |
+
=====inputs=====
|
838 |
+
x: (B, 192, T)
|
839 |
+
x_mask: (B, 1, T)
|
840 |
+
=====outputs=====
|
841 |
+
output: (B, 192, T)
|
842 |
+
"""
|
843 |
+
attention_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(3)
|
844 |
+
# (B, 1, 1, T) * (B, 1, T, 1) = (B, 1, T, T), only consist 0 or 1
|
845 |
+
for i in range(6):
|
846 |
+
x = x * x_mask
|
847 |
+
y = self.attentions[i](x, x, attention_mask)
|
848 |
+
y = self.dropout(y)
|
849 |
+
x = x + y # residual connection
|
850 |
+
x = self.norms1[i](x) # (B, 192, T) 유지
|
851 |
+
|
852 |
+
y = self.ffns[i](x, x_mask)
|
853 |
+
y = self.dropout(y)
|
854 |
+
x = x + y # residual connection
|
855 |
+
x = self.norms2[i](x)
|
856 |
+
output = x * x_mask
|
857 |
+
return output # (B, 192, T)
|
858 |
+
|
859 |
+
class DurationPredictor(nn.Module):
|
860 |
+
"""
|
861 |
+
Encoder의 3번째 모듈
|
862 |
+
"""
|
863 |
+
def __init__(self):
|
864 |
+
super().__init__()
|
865 |
+
self.conv1 = nn.Conv1d(192, 256, kernel_size=3, padding=1) # (B, 192, T) -> (B, 256, T)
|
866 |
+
self.norm1 = LayerNorm(256)
|
867 |
+
self.conv2 = nn.Conv1d(256, 256, kernel_size=3, padding=1) # (B, 256, T) -> (B, 256, T)
|
868 |
+
self.norm2 = LayerNorm(256)
|
869 |
+
self.linear = nn.Conv1d(256, 1, kernel_size=1) # (B, 256, T) -> (B, 1, T)
|
870 |
+
|
871 |
+
self.relu = nn.ReLU()
|
872 |
+
self.dropout = nn.Dropout(0.1)
|
873 |
+
|
874 |
+
def forward(self, x, x_mask):
|
875 |
+
"""
|
876 |
+
=====inputs=====
|
877 |
+
x: (B, 192, T)
|
878 |
+
x_mask: (B, 1, T)
|
879 |
+
=====outputs=====
|
880 |
+
output: (B, 1, T)
|
881 |
+
"""
|
882 |
+
x = self.conv1(x * x_mask) # (B, 192, T) -> (B, 256, T)
|
883 |
+
x = self.relu(x)
|
884 |
+
x = self.norm1(x)
|
885 |
+
x = self.dropout(x)
|
886 |
+
|
887 |
+
x = self.conv2(x * x_mask) # (B, 256, T) -> (B, 256, T)
|
888 |
+
x = self.relu(x)
|
889 |
+
x = self.norm2(x)
|
890 |
+
x = self.dropout(x)
|
891 |
+
|
892 |
+
x = self.linear(x * x_mask) # (B, 256, T) -> (B, 1, T)
|
893 |
+
output = x * x_mask
|
894 |
+
return output
|