marigold334 committed on
Commit
1a91ed2
1 Parent(s): b164712

Create Tmodel.py (#26)


- Create Tmodel.py (095470eacaa34923b2e728093e4f02989fa057b3)

Files changed (1)
  1. Tmodel.py +894 -0
Tmodel.py ADDED
@@ -0,0 +1,894 @@
+ from torch import nn
+ import numpy as np
+ import torch.nn.functional as F
+ from torch.nn.utils import weight_norm
+ import math
+ import torch
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ symbol_length = 73
+
+ class GlowTTS(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.encoder = Encoder()
+         self.decoder = Decoder()
+
+     def forward(self, text, text_len, mel=None, mel_len=None, inference=False):
+         """
+         =====inputs=====
+         text: (B, T)
+         text_len: (B) list
+         mel: (B, 80, F)
+         mel_len: (B) | 1-D integer tensor of mel lengths (a plain list would fail at the // 2 step below)
+         inference: True/False
+         =====outputs=====
+         (tuple) (z, z_mean, z_log_std, log_det, z_mask)
+             z (training) or y (inference): (B, 80, F) | z: latent representation, y: mel-spectrogram
+             z_mean: (B, 80, F)
+             z_log_std: (B, 80, F)
+             log_det: (B) or None
+             z_mask: (B, 1, F)
+         (tuple) (x_mean, x_log_std, x_mask)
+             x_mean: (B, 80, T)
+             x_log_std: (B, 80, T)
+             x_mask: (B, 1, T)
+         (tuple) (attention_alignment, x_log_dur, log_d)
+             attention_alignment: (B, T, F)
+             x_log_dur: (B, 1, T) | predicted duration, in log scale
+             log_d: (B, 1, T) | log-scale duration taken from the estimated alignment
+         """
+         x_mean, x_log_std, x_log_dur, x_mask = self.encoder(text, text_len)
+         # The log in x_log_std and x_log_dur follows the authors' implementation, which treats these values as log-scaled.
+         y, y_len = mel, mel_len
+
+         if not inference:  # training
+             y_max_len = y.size(2)
+         else:  # inference
+             dur = torch.exp(x_log_dur) * x_mask  # (B, 1, T)
+             ceil_dur = torch.ceil(dur)  # (B, 1, T)
+             y_len = torch.clamp_min(torch.sum(ceil_dur, [1, 2]), 1).long()  # (B)
+             # Sum ceil_dur over dims [1, 2], clamp to a minimum of 1, and cast to long.
+             y_max_len = None
+
+         # preprocessing
+         if y_max_len is not None:
+             y_max_len = (y_max_len // 2) * 2  # if odd, subtract 1 to make it even
+             y = y[:, :, :y_max_len]  # trim y to y_max_len
+             y_len = (y_len // 2) * 2  # if y_len is odd, subtract 1 to make it even
+
+         # make the z_mask
+         B = len(y_len)
+         temp_max = max(y_len)
+         z_mask = torch.zeros((B, 1, temp_max), dtype=torch.bool).to(device)  # (B, 1, F)
+         for idx, length in enumerate(y_len):
+             z_mask[idx, :, :length] = True
+
+         # make the attention_mask
+         attention_mask = x_mask.unsqueeze(3) * z_mask.unsqueeze(2)  # (B, 1, T, 1) * (B, 1, 1, F) = (B, 1, T, F)
+         # Note: this is a different mask from the attention_mask used inside the Encoder.
+
+         if not inference:  # training
+             z, log_det = self.decoder(y, z_mask, reverse=False)
+             with torch.no_grad():
+                 x_std_squared_root = torch.exp(-2 * x_log_std)  # (B, 80, T) | exp(-2*log_std) = 1/std**2
+                 logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_log_std, [1]).unsqueeze(-1)  # (B, T, 1), broadcast to (B, T, F)
+                 logp2 = torch.matmul(x_std_squared_root.transpose(1, 2), -0.5 * (z ** 2))  # (B, T, 80) * (B, 80, F) = (B, T, F)
+                 logp3 = torch.matmul((x_mean * x_std_squared_root).transpose(1, 2), z)  # (B, T, 80) * (B, 80, F) = (B, T, F)
+                 logp4 = torch.sum(-0.5 * (x_mean ** 2) * x_std_squared_root, [1]).unsqueeze(-1)  # (B, T, 1), broadcast to (B, T, F)
+                 logp = logp1 + logp2 + logp3 + logp4  # (B, T, F)
+                 """
+                 logp is the log-likelihood of z under the normal distribution N(x_mean, x_std).
+                 Expanding sum(log(N(z; x_mean, x_std))) with the Gaussian density and the distributive law gives the expression above.
+                 """
+                 attention_alignment = maximum_path(logp, attention_mask.squeeze(1)).detach()  # alignment (B, T, F)
+
+             z_mean = torch.matmul(attention_alignment.transpose(1, 2), x_mean.transpose(1, 2))  # (B, F, T) * (B, T, 80) -> (B, F, 80)
+             z_mean = z_mean.transpose(1, 2)  # (B, 80, F)
+             z_log_std = torch.matmul(attention_alignment.transpose(1, 2), x_log_std.transpose(1, 2))  # (B, F, T) * (B, T, 80) -> (B, F, 80)
+             z_log_std = z_log_std.transpose(1, 2)  # (B, 80, F)
+             log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask  # (B, 1, T) | log-scale duration obtained from the alignment
+             return (z, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
+
+         else:  # inference
+             # generate_path (make attention_alignment using ceil(dur))
+             attention_alignment = generate_path(ceil_dur.squeeze(1), attention_mask.squeeze(1))  # (B, T, F)
+             z_mean = torch.matmul(attention_alignment.transpose(1, 2), x_mean.transpose(1, 2))  # (B, F, T) * (B, T, 80) -> (B, F, 80)
+             z_mean = z_mean.transpose(1, 2)  # (B, 80, F)
+             z_log_std = torch.matmul(attention_alignment.transpose(1, 2), x_log_std.transpose(1, 2))  # (B, F, T) * (B, T, 80) -> (B, F, 80)
+             z_log_std = z_log_std.transpose(1, 2)  # (B, 80, F)
+             log_d = torch.log(1e-8 + torch.sum(attention_alignment, -1)).unsqueeze(1) * x_mask  # (B, 1, T) | log-scale duration obtained from the alignment
+
+             z = (z_mean + torch.exp(z_log_std) * torch.randn_like(z_mean)) * z_mask  # sample the latent representation z
+             y, log_det = self.decoder(z, z_mask, reverse=True)  # generate the mel-spectrogram
+             return (y, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attention_alignment, x_log_dur, log_d)
+
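A minimal usage sketch (editorial illustration, not part of the committed file; run after the whole file is loaded). The batch sizes and lengths are arbitrary, and mel_len is passed as an integer tensor so the `// 2` trimming step works:

    # Illustrative only: toy inputs with arbitrary sizes.
    model = GlowTTS().to(device)
    text = torch.randint(0, symbol_length, (2, 20)).to(device)  # (B, T) token ids
    text_len = [20, 14]                                         # (B)
    mel = torch.randn(2, 80, 160).to(device)                    # (B, 80, F)
    mel_len = torch.tensor([160, 120])                          # (B) integer lengths

    # training: mel is encoded into the latent z and log_det is returned for the flow loss
    (z, z_mean, z_log_std, log_det, z_mask), (x_mean, x_log_std, x_mask), (attn, x_log_dur, log_d) = \
        model(text, text_len, mel, mel_len, inference=False)

    # inference: only text is needed; the decoder runs in reverse to produce a mel-spectrogram y
    with torch.no_grad():
        (y, _, _, _, _), _, _ = model(text, text_len, inference=True)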
+ ##### The authors' implementation below is much faster; the implementation above should be revised to follow it. #####
+ # Note: maximum_path and generate_path are defined again further below (pure-Python versions);
+ # since Python keeps the last definition, those later versions are the ones actually used.
+ def maximum_path(value, mask, max_neg_val=-np.inf):
+     """ Numpy-friendly version. It's about 4 times faster than the torch version.
+     value: [b, t_x, t_y]
+     mask: [b, t_x, t_y]
+     """
+     value = value * mask
+
+     device = value.device
+     dtype = value.dtype
+     value = value.cpu().detach().numpy()
+     mask = mask.cpu().detach().numpy().astype(bool)
+
+     b, t_x, t_y = value.shape
+     direction = np.zeros(value.shape, dtype=np.int64)
+     v = np.zeros((b, t_x), dtype=np.float32)
+     x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
+     for j in range(t_y):
+         v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
+         v1 = v
+         max_mask = (v1 >= v0)
+         v_max = np.where(max_mask, v1, v0)
+         direction[:, :, j] = max_mask
+
+         index_mask = (x_range <= j)
+         v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
+     direction = np.where(mask, direction, 1)
+
+     path = np.zeros(value.shape, dtype=np.float32)
+     index = mask[:, :, 0].sum(1).astype(np.int64) - 1
+     index_range = np.arange(b)
+     for j in reversed(range(t_y)):
+         path[index_range, index, j] = 1
+         index = index + direction[index_range, index, j] - 1
+     path = path * mask.astype(np.float32)
+     path = torch.from_numpy(path).to(device=device, dtype=dtype)
+     return path
+
+
+ def generate_path(duration, mask):
+     """
+     duration: [b, t_x]
+     mask: [b, t_x, t_y]
+     """
+     device = duration.device
+
+     b, t_x, t_y = mask.shape  # (B, T, F)
+     cum_duration = torch.cumsum(duration, 1)  # cumulative sum, (B, T)
+     path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)  # (B, T, F)
+
+     cum_duration_flat = cum_duration.view(b * t_x)  # (B*T)
+     path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)  # (B*T, F)
+     path = path.view(b, t_x, t_y)  # (B, T, F)
+     path = path.to(torch.float32)
+     path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]  # (B, T, F) | subtract the path shifted by one along T
+     path = path * mask
+     return path
+
+ def sequence_mask(length, max_length=None):
+     if max_length is None:
+         max_length = length.max()
+     x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+     return x.unsqueeze(0) < length.unsqueeze(1)
+
+ def convert_pad_shape(pad_shape):
+     l = pad_shape[::-1]  # [[0, 0], [p, p], [0, 0]]
+     pad_shape = [item for sublist in l for item in sublist]  # [0, 0, p, p, 0, 0]
+     return pad_shape
+
+ def MAS(path, logp, T_max, F_max):
+     """
+     Helper used by maximum_path (a Glow-TTS module).
+     Runs the MAS (Monotonic Alignment Search) algorithm.
+     =====inputs=====
+     path: (T, F)
+     logp: (T, F)
+     T_max: (1)
+     F_max: (1)
+     =====outputs=====
+     path: (T, F) | alignment made up of 0s and 1s
+     """
+     neg_inf = -1e9  # negative infinity
+     # forward
+     for j in range(F_max):
+         for i in range(max(0, T_max + j - F_max), min(T_max, j + 1)):  # think of the parallelogram of valid (i, j) pairs
+             # Q_i_j-1 (current)
+             if i == j:
+                 Q_cur = neg_inf
+             else:
+                 Q_cur = logp[i, j-1]  # if j == 0 then i == 0 as well, so using j-1 here is safe
+
+             # Q_i-1_j-1 (previous)
+             if i == 0:
+                 if j == 0:
+                     Q_prev = 0.  # when i = 0 and j = 0, only the logp value should be kept
+                 else:
+                     Q_prev = neg_inf  # when i = 0, Q_i-1_j-1 must not contribute
+             else:
+                 Q_prev = logp[i-1, j-1]
+
+             # update logp with Q
+             logp[i, j] = max(Q_cur, Q_prev) + logp[i, j]
+
+     # backtracking
+     idx = T_max - 1
+     for j in range(F_max-1, -1, -1):  # from F_max-1 down to 0 (inclusive), step -1
+         path[idx, j] = 1
+         if idx != 0:
+             if (logp[idx, j-1] < logp[idx-1, j-1]) or (idx == j):
+                 idx -= 1
+
+     return path
+
+
+ def maximum_path(logp, attention_mask):
+     """
+     Module used in Glow-TTS.
+     Finds the alignment using MAS.
+     The paper authors appear to parallelize this with Cython, but here it is implemented in pure Python.
+     =====inputs=====
+     logp: (B, T, F) | log-likelihood under N(x_mean, x_std)
+     attention_mask: (B, T, F)
+     =====outputs=====
+     path: (B, T, F) | alignment
+     """
+     B = logp.shape[0]
+
+     logp = logp * attention_mask
+     # Save the original device/dtype and move to CPU/NumPy so the computation runs on the CPU.
+     logp_device = logp.device
+     logp_type = logp.dtype
+     logp = logp.data.cpu().numpy().astype(np.float32)
+     attention_mask = attention_mask.data.cpu().numpy()
+
+     path = np.zeros_like(logp).astype(np.int32)  # (B, T, F)
+     T_max = attention_mask.sum(1)[:, 0].astype(np.int32)  # (B)
+     F_max = attention_mask.sum(2)[:, 0].astype(np.int32)  # (B)
+
+     # run the MAS algorithm per batch element
+     for idx in range(B):
+         path[idx] = MAS(path[idx], logp[idx], T_max[idx], F_max[idx])  # (T, F)
+     return torch.from_numpy(path).to(device=logp_device, dtype=logp_type)
+
+ def generate_path(ceil_dur, attention_mask):
+     """
+     Module used in Glow-TTS.
+     Builds the alignment during inference.
+     =====input=====
+     ceil_dur: (B, T) | predicted durations after ceil | e.g. [[2, 1, 2, 2, ...], [1, 2, 1, 3, ...], ...]
+     attention_mask: (B, T, F)
+     =====output=====
+     path: (B, T, F) | alignment
+     """
+     B, T, Frame = attention_mask.shape
+     cum_dur = torch.cumsum(ceil_dur, 1)
+     cum_dur = cum_dur.to(torch.int32)  # (B, T) | cumulative sum | e.g. [[2, 3, 5, 7, ...], [1, 3, 4, 7, ...], ...]
+     path = torch.zeros(B, T, Frame).to(ceil_dur.device)  # (B, T, F) | all False(0)
+
+     # make the sequence_mask
+     for b, batch_cum_dur in enumerate(cum_dur):
+         for t, each_cum_dur in enumerate(batch_cum_dur):
+             path[b, t, :each_cum_dur] = 1.0
+             # write True(1) into path according to cum_dur (scalar assignment broadcasts over the slice)
+     path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]  # (B, T, F)
+     """
+     e.g. (batch dimension omitted for brevity)
+     [[1, 1, 0, 0, 0, 0, 0],    [[0, 0, 0, 0, 0, 0, 0],    [[1, 1, 0, 0, 0, 0, 0],
+      [1, 1, 1, 0, 0, 0, 0],  -  [1, 1, 0, 0, 0, 0, 0],  =  [0, 0, 1, 0, 0, 0, 0],
+      [1, 1, 1, 1, 1, 0, 0],     [1, 1, 1, 0, 0, 0, 0],     [0, 0, 0, 1, 1, 0, 0],
+      [1, 1, 1, 1, 1, 1, 1]]     [1, 1, 1, 1, 1, 0, 0]]     [0, 0, 0, 0, 0, 1, 1]]
+     """
+     path = path * attention_mask
+     return path
+
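A small worked example (editorial illustration) of what generate_path produces for a toy duration sequence:

    ceil_dur = torch.tensor([[2., 1., 2.]])        # (B=1, T=3) predicted durations
    attention_mask = torch.ones(1, 3, 5)           # (B, T, F) with F = 2 + 1 + 2 = 5
    path = generate_path(ceil_dur, attention_mask)
    # path[0] ==
    # [[1., 1., 0., 0., 0.],
    #  [0., 0., 1., 0., 0.],
    #  [0., 0., 0., 1., 1.]]   -> each frame is assigned to exactly one token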
+ class Decoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.flows = nn.ModuleList()
+         for i in range(12):
+             self.flows.append(ActNorm())
+             self.flows.append(InvertibleConv())
+             self.flows.append(AffineCouplingLayer())
+
+     def forward(self, x, x_mask, reverse=False):
+         """
+         =====inputs=====
+         x: (B, 80, F) | mel-spectrogram (direct) OR latent representation (reverse)
+         x_mask: (B, 1, F)
+         =====outputs=====
+         z: (B, 80, F) | latent representation (direct) OR mel-spectrogram (reverse)
+         total_log_det: (B) or None | log determinant
+         """
+         if not reverse:
+             flows = self.flows
+             total_log_det = 0
+         else:
+             flows = reversed(self.flows)
+             total_log_det = None
+
+         x, x_mask = Squeeze(x, x_mask)  # (B, 80, F) -> (B, 160, F//2) | (B, 1, F) -> (B, 1, F//2)
+
+         for f in flows:
+             if not reverse:
+                 x, log_det = f(x, x_mask, reverse=reverse)
+                 total_log_det += log_det
+             else:
+                 x, _ = f(x, x_mask, reverse=reverse)
+
+         x, x_mask = Unsqueeze(x, x_mask)  # (B, 160, F//2) -> (B, 80, F) | (B, 1, F//2) -> (B, 1, F)
+
+         return x, total_log_det
+
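A quick invertibility check for the Decoder (editorial sketch; run after the whole file is loaded). In eval mode the WN dropout is disabled, so running the flow forward and then in reverse should recover the input up to numerical error:

    dec = Decoder().eval()                      # eval() disables the dropout inside WN
    mel = torch.randn(2, 80, 100)               # (B, 80, F), F even
    mel_mask = torch.ones(2, 1, 100)            # (B, 1, F)
    with torch.no_grad():
        z, log_det = dec(mel, mel_mask, reverse=False)
        mel_rec, _ = dec(z, mel_mask, reverse=True)
    print(torch.allclose(mel, mel_rec, atol=1e-4))   # expected: True (up to float error)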
+ """
+ The Decoder follows the basic structure of Glow: Generative Flow with Invertible 1×1 Convolutions.
+ Glow paper: https://arxiv.org/pdf/1807.03039.pdf
+ """
+ def Squeeze(x, x_mask):
+     """
+     Pre-processing for the Decoder.
+     =====inputs=====
+     x: (B, 80, F) | mel-spectrogram or latent representation
+     x_mask: (B, 1, F)
+     =====outputs=====
+     x: (B, 160, F//2) | F//2 = floor(F/2)
+     x_mask: (B, 1, F//2)
+     """
+     B, C, f = x.size()  # use f to avoid shadowing torch.nn.functional (F)
+     x = x[:, :, :(f//2)*2]  # if F is odd, drop the last frame
+     x = x.view(B, C, f//2, 2)  # (B, 80, F//2, 2)
+     x = x.permute(0, 3, 1, 2).contiguous()  # (B, 2, 80, F//2)
+     x = x.view(B, C*2, f//2)  # (B, 160, F//2)
+
+     x_mask = x_mask[:, :, 1::2]  # (B, 1, F//2) | take every other frame, starting from index 1
+     x = x * x_mask  # masking
+     return x, x_mask
+
+ class ActNorm(nn.Module):
+     """
+     1st module of the Decoder.
+     """
+     def __init__(self):
+         super().__init__()
+         self.log_s = nn.Parameter(torch.zeros(1, 160, 1))  # log of s from the Glow paper, i.e. log[s]
+         self.bias = nn.Parameter(torch.zeros(1, 160, 1))
+
+     def forward(self, x, x_mask, reverse=False):
+         """
+         =====inputs=====
+         x: (B, 160, F//2) | mel-spectrogram features
+         x_mask: (B, 1, F//2) | mask for the mel-spectrogram features (reshaped by Squeeze in the Decoder)
+         =====outputs=====
+         z: (B, 160, F//2)
+         log_det: (B) or None | log determinant; None is returned when reverse=True
+         """
+         x_len = torch.sum(x_mask, [1, 2])  # (B) | sum over dims 1 and 2; using only [2] would give shape (B, 1)
+
+         if not reverse:
+             z = (x * torch.exp(self.log_s) + self.bias) * x_mask  # function & masking
+             log_det = x_len * torch.sum(self.log_s)  # log determinant
+             # See Table 1 of the Glow paper; log_s corresponds to log[s].
+             # The log-determinant is used instead of the determinant, presumably for smaller values and less computation.
+         else:
+             z = ((x - self.bias) / torch.exp(self.log_s)) * x_mask  # inverse function & masking
+             log_det = None
+
+         return z, log_det
+
+ class InvertibleConv(nn.Module):
+     """
+     2nd module of the Decoder.
+     """
+     def __init__(self):
+         super().__init__()
+         Q = torch.linalg.qr(torch.FloatTensor(4, 4).normal_())[0]  # (4, 4)
+         """
+         torch.FloatTensor(4, 4).normal_(): a 4x4 matrix sampled from the normal distribution N(0, 1)
+         Q, R = torch.linalg.qr(W): QR decomposition | Q: orthogonal matrix, R: upper triangular matrix cf. det(Q) = 1 or -1
+         """
+         if torch.det(Q) < 0:
+             Q[:, 0] = -1 * Q[:, 0]  # flip the sign of the first column so that det(Q) = +1
+         self.W = nn.Parameter(Q)
+
+     def forward(self, x, x_mask, reverse=False):
+         """
+         =====inputs=====
+         x: (B, 160, F//2)
+         x_mask: (B, 1, F//2)
+         =====outputs=====
+         z: (B, 160, F//2)
+         log_det: (B) or None
+         """
+         B, C, f = x.size()  # B, 160, F//2
+         x_len = torch.sum(x_mask, [1, 2])  # (B)
+
+         # channel mixing
+         x = x.view(B, 2, C//4, 2, f)  # (B, 2, 40, 2, F//2)
+         x = x.permute(0, 1, 3, 2, 4).contiguous()  # (B, 2, 2, 40, F//2)
+         x = x.view(B, 4, C//4, f)  # (B, 4, 40, F//2)
+
+         # compute log_det first, for convenience
+         if not reverse:
+             weight = self.W
+             log_det = (C/4) * x_len * torch.logdet(self.W)  # (B) | torch.logdet(W): log(det(W))
+             # With height = C/4 and width = x_len, this matches the log-determinant formula in the Glow paper.
+         else:
+             weight = torch.linalg.inv(self.W)  # inverse matrix
+             log_det = None
+
+         weight = weight.view(4, 4, 1, 1)
+         z = F.conv2d(x, weight)  # (B, 4, 40, F//2) * (4, 4, 1, 1) -> (B, 4, 40, F//2)
+         """
+         The convolution F.conv2d(x, weight) should be read as follows:
+         (B, 4, 40, F//2): (batch_size, in_channels, height, width)
+         (4, 4, 1, 1): (out_channels, in_channels/groups, kernel_height, kernel_width)
+
+         That is, it is nn.Conv2d(4, 4, kernel_size=(1, 1)) with an explicitly provided weight.
+         """
+
+         # channel unmixing
+         z = z.view(B, 2, 2, C//4, f)  # (B, 4, 40, F//2) -> (B, 2, 2, 40, F//2)
+         z = z.permute(0, 1, 3, 2, 4).contiguous()  # (B, 2, 40, 2, F//2)
+         z = z.view(B, C, f) * x_mask  # (B, 160, F//2) & masking
+         return z, log_det
+
+ class WN(nn.Module):
+     """
+     Sub-module of AffineCouplingLayer, the 3rd module of the Decoder.
+
+     This structure was proposed in WaveGlow: A Flow-based Generative Network for Speech Synthesis.
+     WaveGlow paper: https://arxiv.org/pdf/1811.00002.pdf
+     """
+     def __init__(self, dilation_rate=1):
+         super().__init__()
+         self.in_layers = nn.ModuleList()
+         self.res_skip_layers = nn.ModuleList()
+
+         for i in range(4):
+             dilation = dilation_rate ** i  # NVIDIA's WaveGlow uses dilation_rate=2; here it is 1, so this has no effect
+             in_layer = weight_norm(nn.Conv1d(192, 2*192, kernel_size=5, dilation=dilation,
+                                              padding=((5-1) * dilation)//2))  # (B, 192, F//2) -> (B, 2*192, F//2)
+             self.in_layers.append(in_layer)
+
+             if i < 3:
+                 res_skip_layer = weight_norm(nn.Conv1d(192, 2*192, kernel_size=1))  # (B, 192, F//2) -> (B, 2*192, F//2)
+             else:
+                 res_skip_layer = weight_norm(nn.Conv1d(192, 192, kernel_size=1))  # (B, 192, F//2) -> (B, 192, F//2)
+             self.res_skip_layers.append(res_skip_layer)
+
+         self.dropout = nn.Dropout(0.05)
+
+     def forward(self, x, x_mask):
+         """
+         =====inputs=====
+         x: (B, 192, F//2)
+         x_mask: (B, 1, F//2)
+         =====outputs=====
+         output: (B, 192, F//2)
+         """
+         output = torch.zeros_like(x)  # (B, 192, F//2), all zeros
+
+         for i in range(4):
+             x_in = self.in_layers[i](x)  # (B, 192, F//2) -> (B, 2*192, F//2)
+             x_in = self.dropout(x_in)  # dropout
+
+             # fused add-tanh-sigmoid-multiply
+             tanh_act = torch.tanh(x_in[:, :192, :])  # (B, 192, F//2)
+             sigmoid_act = torch.sigmoid(x_in[:, 192:, :])  # (B, 192, F//2)
+
+             acts = sigmoid_act * tanh_act  # (B, 192, F//2)
+
+             x_out = self.res_skip_layers[i](acts)  # (B, 192, F//2) -> (B, 2*192, F//2) or [last] (B, 192, F//2)
+             if i < 3:
+                 x = (x + x_out[:, :192, :]) * x_mask  # residual connection & masking
+                 output += x_out[:, 192:, :]  # add to output
+             else:
+                 output += x_out  # (B, 192, F//2)
+
+         output = output * x_mask  # masking
+         return output
+
+ class AffineCouplingLayer(nn.Module):
+     """
+     3rd module of the Decoder.
+     """
+     def __init__(self):
+         super().__init__()
+         self.start_conv = weight_norm(nn.Conv1d(160//2, 192, kernel_size=1))  # (B, 80, F//2) -> (B, 192, F//2)
+         self.wn = WN()
+         self.end_conv = nn.Conv1d(192, 160, kernel_size=1)  # (B, 192, F//2) -> (B, 160, F//2)
+         # Initializing end_conv to zero makes the layer start out as an identity mapping, which helps stabilize training.
+         self.end_conv.weight.data.zero_()  # initialize weight to zero
+         self.end_conv.bias.data.zero_()  # initialize bias to zero
+
+     def forward(self, x, x_mask, reverse=False):
+         """
+         =====inputs=====
+         x: (B, 160, F//2)
+         x_mask: (B, 1, F//2)
+         =====outputs=====
+         z: (B, 160, F//2)
+         log_det: (B) or None
+         """
+         B, C, f = x.size()  # B, 160, F//2
+         x_0, x_1 = x[:, :C//2, :], x[:, C//2:, :]  # split: (B, 80, F//2) x2
+
+         x = self.start_conv(x_0) * x_mask  # (B, 80, F//2) -> (B, 192, F//2) & masking
+         x = self.wn(x, x_mask)  # (B, 192, F//2)
+         out = self.end_conv(x)  # (B, 192, F//2) -> (B, 160, F//2)
+
+         z_0 = x_0  # (B, 80, F//2)
+         m = out[:, :C//2, :]  # (B, 80, F//2)
+         log_s = out[:, C//2:, :]  # (B, 80, F//2)
+
+         if not reverse:
+             z_1 = (torch.exp(log_s) * x_1 + m) * x_mask  # (B, 80, F//2) | function & masking
+             log_det = torch.sum(log_s * x_mask, [1, 2])  # (B)
+         else:
+             z_1 = (x_1 - m) / torch.exp(log_s) * x_mask  # (B, 80, F//2) | inverse function & masking
+             log_det = None
+
+         z = torch.cat([z_0, z_1], dim=1)  # (B, 160, F//2)
+         return z, log_det
+
+ def Unsqueeze(x, x_mask):
+     """
+     Post-processing for the Decoder.
+     =====inputs=====
+     x: (B, 160, F//2)
+     x_mask: (B, 1, F//2)
+     =====outputs=====
+     x: (B, 80, F)
+     x_mask: (B, 1, F)
+     """
+     B, C, f = x.size()  # B, 160, F//2
+     x = x.view(B, 2, C//2, f)  # (B, 2, 80, F//2)
+     x = x.permute(0, 2, 3, 1).contiguous()  # (B, 80, F//2, 2)
+     x = x.view(B, C//2, 2*f)  # (B, 80, F)
+
+     x_mask = x_mask.unsqueeze(3).repeat(1, 1, 1, 2).view(B, 1, 2*f)  # (B, 1, F//2, 1) -> (B, 1, F//2, 2) -> (B, 1, F)
+     x = x * x_mask  # masking
+     return x, x_mask
+
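A round-trip sketch (editorial illustration): for an even number of frames and an all-ones mask, Unsqueeze exactly undoes Squeeze.

    x = torch.randn(2, 80, 100)              # (B, 80, F)
    x_mask = torch.ones(2, 1, 100)           # (B, 1, F)
    x_sq, mask_sq = Squeeze(x, x_mask)       # (2, 160, 50), (2, 1, 50)
    x_rec, mask_rec = Unsqueeze(x_sq, mask_sq)
    print(torch.equal(x, x_rec))             # expected: True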
+ class Encoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.embedding = nn.Embedding(symbol_length, 192)  # (B, T) -> (B, T, 192)
+         nn.init.normal_(self.embedding.weight, 0.0, 192**(-0.5))  # initialize weights from a normal distribution, N(0, 0.07...)
+
+         self.prenet = PreNet()
+         self.transformer_encoder = TransformerEncoder()
+         self.project_mean = nn.Conv1d(192, 80, kernel_size=1)  # (B, 192, T) -> (B, 80, T)
+         self.project_std = nn.Conv1d(192, 80, kernel_size=1)  # (B, 192, T) -> (B, 80, T)
+
+         self.duration_predictor = DurationPredictor()
+
+     def forward(self, text, text_len):
+         """
+         =====inputs=====
+         text: (B, Max_T)
+         text_len: (B)
+         =====outputs=====
+         x_mean: (B, 80, T) | mean; the authors' train.py sets out_channels to 80
+         x_std: (B, 80, T) | (log) standard deviation
+         x_dur: (B, 1, T)
+         x_mask: (B, 1, T)
+         """
+         x = self.embedding(text) * math.sqrt(192)  # (B, T) -> (B, T, 192) | math.sqrt(192) ≈ 13.86
+         x = x.transpose(1, 2)  # (B, T, 192) -> (B, 192, T)
+
+         # Make the x_mask
+         x_mask = torch.zeros_like(x[:, 0:1, :], dtype=torch.bool)  # (B, 1, T)
+         for idx, length in enumerate(text_len):
+             x_mask[idx, :, :length] = True
+
+         x = self.prenet(x, x_mask)  # (B, 192, T)
+         x = self.transformer_encoder(x, x_mask)  # (B, 192, T)
+
+         # project
+         x_mean = self.project_mean(x) * x_mask  # (B, 192, T) -> (B, 80, T)
+         # x_std = self.project_std(x) * x_mask  # (B, 192, T) -> (B, 80, T)
+         ##### The line below applies the mean_only option. #####
+         x_std = torch.zeros_like(x_mean)  # x_log_std: (B, 80, T), all zeros | log std = 0, so std = 1
+
+         # duration predictor
+         x_dp = torch.detach(x)  # stop_gradient
+         x_dur = self.duration_predictor(x_dp, x_mask)  # (B, 192, T) -> (B, 1, T)
+
+         return x_mean, x_std, x_dur, x_mask
+
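A shape sketch for the Encoder (editorial illustration; run after the whole file is loaded), using arbitrary toy text lengths:

    enc = Encoder()
    text = torch.randint(0, symbol_length, (2, 15))          # (B, T) token ids
    x_mean, x_log_std, x_log_dur, x_mask = enc(text, [15, 11])
    print(x_mean.shape, x_log_std.shape, x_log_dur.shape, x_mask.shape)
    # torch.Size([2, 80, 15]) torch.Size([2, 80, 15]) torch.Size([2, 1, 15]) torch.Size([2, 1, 15])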
+ class LayerNorm(nn.Module):
+     """
+     Normalization module used in several places.
+
+     PyTorch already provides nn.LayerNorm, but it always normalizes the last dimension,
+     so a LayerNorm that normalizes over the channel dimension is implemented separately here.
+     """
+     def __init__(self, channels):
+         """
+         channels: number of channels of the input | this LayerNorm normalizes the channel dimension
+         """
+         super().__init__()
+         self.channels = channels
+         self.eps = 1e-4
+
+         self.gamma = nn.Parameter(torch.ones(channels))  # learnable parameter
+         self.beta = nn.Parameter(torch.zeros(channels))  # learnable parameter
+
+     def forward(self, x):
+         """
+         =====inputs=====
+         x: (B, channels, *) | input data to normalize
+         =====outputs=====
+         x: (B, channels, *) | data normalized over the channel dimension
+         """
+         mean = torch.mean(x, dim=1, keepdim=True)  # mean over the channel dimension (dim=1), keeping dims
+         variance = torch.mean((x-mean)**2, dim=1, keepdim=True)  # variance
+
+         x = (x - mean) * (variance + self.eps)**(-0.5)  # (x - m) / sqrt(v)
+
+         n = len(x.shape)
+         shape = [1] * n
+         shape[1] = -1  # shape = [1, -1, 1] or [1, -1, 1, 1]
+         x = x * self.gamma.view(*shape) + self.beta.view(*shape)  # y = x*gamma + beta
+
+         return x
+
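An equivalence sketch (editorial illustration): at initialization this channel-wise LayerNorm matches nn.LayerNorm applied over the channel dimension via a transpose, since both use eps=1e-4 and ones/zeros affine parameters.

    ln = LayerNorm(192)
    ref = nn.LayerNorm(192, eps=1e-4)
    x = torch.randn(2, 192, 30)                               # (B, C, T)
    out = ln(x)                                               # normalizes dim 1 (channels)
    out_ref = ref(x.transpose(1, 2)).transpose(1, 2)          # normalize the last dim, then restore layout
    print(torch.allclose(out, out_ref, atol=1e-5))            # expected: True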
+ class PreNet(nn.Module):
+     """
+     1st module of the Encoder.
+     """
+     def __init__(self):
+         super().__init__()
+         self.convs = nn.ModuleList()
+         self.norms = nn.ModuleList()
+         self.relu = nn.ReLU()
+         self.dropout = nn.Dropout(0.5)
+         for i in range(3):
+             self.convs.append(nn.Conv1d(192, 192, kernel_size=5, padding=2))  # keeps (B, 192, T)
+             self.norms.append(LayerNorm(192))  # keeps (B, 192, T)
+         self.linear = nn.Conv1d(192, 192, kernel_size=1)  # keeps (B, 192, T) | 1x1 conv acting as a linear layer
+
+     def forward(self, x, x_mask):
+         """
+         =====inputs=====
+         x: (B, 192, T) | embedded input
+         x_mask: (B, 1, T) | mask by text length (True where a character exists, False elsewhere)
+         =====outputs=====
+         x: (B, 192, T)
+         """
+         x0 = x
+         for i in range(3):
+             x = self.convs[i](x * x_mask)
+             x = self.norms[i](x)
+             x = self.relu(x)
+             x = self.dropout(x)
+         x = self.linear(x)
+         x = x0 + x  # residual connection
+         return x
+
+ class MultiHeadAttention(nn.Module):
+     """
+     1st module of the TransformerEncoder, which is the 2nd module of the Encoder.
+     """
+     def __init__(self):
+         super().__init__()
+         self.n_heads = 2
+         self.window_size = 4
+         self.k_channels = 192 // self.n_heads  # 96
+
+         self.linear_q = nn.Conv1d(192, 192, kernel_size=1)  # keeps (B, 192, T)
+         self.linear_k = nn.Conv1d(192, 192, kernel_size=1)  # keeps (B, 192, T)
+         self.linear_v = nn.Conv1d(192, 192, kernel_size=1)  # keeps (B, 192, T)
+         nn.init.xavier_uniform_(self.linear_q.weight)
+         nn.init.xavier_uniform_(self.linear_k.weight)
+         nn.init.xavier_uniform_(self.linear_v.weight)
+
+         relative_std = self.k_channels ** (-0.5)  # about 0.1
+         self.relative_k = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.k_channels) * relative_std)  # (1, 9, 96)
+         self.relative_v = nn.Parameter(torch.randn(1, self.window_size * 2 + 1, self.k_channels) * relative_std)  # (1, 9, 96)
+
+         self.attention_weights = None
+         self.linear_out = nn.Conv1d(192, 192, kernel_size=1)  # keeps (B, 192, T)
+         self.dropout = nn.Dropout(0.1)
+
+     def forward(self, query, context, attention_mask, self_attention=True):
+         """
+         =====inputs=====
+         query: (B, 192, T_target) | Glow-TTS only uses self-attention, so query and context are the same tensor x
+         context: (B, 192, T_source) | query = context, and in particular T_source = T_target here
+         attention_mask: (B, 1, T, T) | x_mask.unsqueeze(2) * x_mask.unsqueeze(3)
+         self_attention: True/False | relative position representations are applied for self-attention; always True here
+         # In practice, just pass the same tensor x as both query and context.
+         =====outputs=====
+         output: (B, 192, T)
+         """
+
+         query = self.linear_q(query)
+         key = self.linear_k(context)
+         value = self.linear_v(context)
+
+         B, _, T_tar = query.size()
+         T_src = key.size(2)
+         query = query.view(B, self.n_heads, self.k_channels, T_tar).transpose(2, 3)
+         key = key.view(B, self.n_heads, self.k_channels, T_src).transpose(2, 3)
+         value = value.view(B, self.n_heads, self.k_channels, T_src).transpose(2, 3)
+         # (B, 192, T_src) -> (B, 2, 96, T_src) -> (B, 2, T_src, 96)
+
+         scores = torch.matmul(query, key.transpose(2, 3)) / (self.k_channels ** 0.5)
+         # (B, 2, T_tar, 96) * (B, 2, 96, T_src) -> (B, 2, T_tar, T_src)
+
+         if self_attention:  # True
+             # Get relative embeddings (relative_keys) (1-1)
+             padding = max(T_src - (self.window_size + 1), 0)  # max(T-5, 0)
+             start_pos = max((self.window_size + 1) - T_src, 0)  # max(5-T, 0)
+             end_pos = start_pos + 2 * T_src - 1  # (2*T-1) or (T+4)
+             relative_keys = F.pad(self.relative_k, (0, 0, padding, padding))
+             # (1, 9, 96) -> (1, pad+9+pad, 96) = (1, 2T-1, 96)
+             """
+             In F.pad(input, pad) above, pad = (0, 0, padding, padding) means:
+             - the first (0, 0): pad the last dimension of input by 0 in front and 0 behind
+             - the next (padding, padding): pad the second-to-last dimension by `padding` in front and behind
+             In other words, the pad tuple is read from the last dimension backwards.
+             """
+             relative_keys = relative_keys[:, start_pos:end_pos, :]  # (1, 2T-1, 96)
+
+             # Matmul with relative keys (2-1)
+             relative_keys = relative_keys.unsqueeze(0).transpose(2, 3)  # (1, 2T-1, 96) -> (1, 1, 2T-1, 96) -> (1, 1, 96, 2T-1)
+             x = torch.matmul(query, relative_keys)  # (B, 2, T_tar, 96) * (1, 1, 96, 2T_src-1) = (B, 2, T, 2T-1)
+             # In self-attention T_tar = T_src, so no separate handling is needed.
+
+             # Relative position to absolute position (3-1)
+             T = T_tar  # also used for the absolute-to-relative conversion below
+             x = F.pad(x, (0, 1))  # (B, 2, T, 2*T-1) -> (B, 2, T, 2*T)
+             x = x.view(B, self.n_heads, T * 2 * T)  # (B, 2, T, 2*T) -> (B, 2, 2T^2)
+             x = F.pad(x, (0, T-1))  # (B, 2, 2T^2 + T - 1)
+             x = x.view(B, self.n_heads, T+1, 2*T-1)  # (B, 2, T+1, 2T-1)
+             relative_logits = x[:, :, :T, T-1:]  # (B, 2, T, T)
+
+             # Compute scores
+             scores_local = relative_logits / (self.k_channels ** 0.5)
+             scores = scores + scores_local  # (B, 2, T, T)
+             """
+             The expression above implements Eq. (5) of Self-Attention with Relative Position Representations.
+             Paper: https://arxiv.org/pdf/1803.02155.pdf
+             """
+
+         scores = scores.masked_fill(attention_mask == 0, -1e4)  # fill positions where attention_mask is 0 with a large negative value
+
+         attention_weights = F.softmax(scores, dim=-1)  # (B, 2, T_tar, T_src) | corresponds to alpha in the relative-position paper
+         attention_weights = self.dropout(attention_weights)  # attention dropout, as in the original Transformer
+         output = torch.matmul(attention_weights, value)  # (B, 2, T_tar, T_src) * (B, 2, T_src, 96) -> (B, 2, T_tar, 96)
+
+         if self_attention:  # True
+             # Absolute position to relative position (3-2)
+             x = F.pad(attention_weights, (0, T-1))  # (B, 2, T, T) -> (B, 2, T, 2T-1)
+             x = x.view((B, self.n_heads, T * (2*T-1)))  # (B, 2, 2T^2-T)
+             x = F.pad(x, (T, 0))  # (B, 2, 2T^2) | pad at the front
+             x = x.view((B, self.n_heads, T, 2*T))  # (B, 2, T, 2T)
+             relative_weights = x[:, :, :, 1:]  # (B, 2, T, 2T-1)
+
+             # Get relative embeddings (relative_value) (1-2) | almost identical to (1-1)
+             padding = max(T_src - (self.window_size + 1), 0)  # max(T-5, 0)
+             start_pos = max((self.window_size + 1) - T_src, 0)  # max(5-T, 0)
+             end_pos = start_pos + 2 * T_src - 1  # (2*T-1) or (T+4)
+             relative_values = F.pad(self.relative_v, (0, 0, padding, padding))
+             # (1, 9, 96) -> (1, pad+9+pad, 96) = (1, 2T-1, 96)
+             relative_values = relative_values[:, start_pos:end_pos, :]  # (1, 2T-1, 96)
+
+             # Matmul with relative values (2-2)
+             relative_values = relative_values.unsqueeze(0)  # (1, 1, 2T-1, 96)
+
+             output = output + torch.matmul(relative_weights, relative_values)
+             # (B, 2, T, 2T-1) * (1, 1, 2T-1, 96) = (B, 2, T, 96)
+             """
+             The expression above implements Eq. (3) of Self-Attention with Relative Position Representations (using the distributive law).
+             Paper: https://arxiv.org/pdf/1803.02155.pdf
+             """
+
+         output = output.transpose(2, 3).contiguous().view(B, 192, T_tar)
+         # (B, 2, 96, T) -> make contiguous in memory -> (B, 192, T)
+
+         self.attention_weights = attention_weights  # (B, 2, T, T)
+         output = self.linear_out(output)
+         return output  # (B, 192, T)
+
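A small illustration (editorial sketch) of the relative-to-absolute reindexing used in step (3-1) above: if the last-dimension index of the input encodes the relative-position bucket 0 .. 2T-2, entry (i, j) of the output ends up holding bucket j - i + (T - 1).

    B, H, T = 1, 1, 3
    rel = torch.arange(2*T - 1, dtype=torch.float).expand(B, H, T, 2*T - 1)  # (1, 1, 3, 5), value = bucket index
    x = F.pad(rel, (0, 1)).view(B, H, T * 2 * T)
    x = F.pad(x, (0, T - 1)).view(B, H, T + 1, 2*T - 1)
    absolute = x[:, :, :T, T - 1:]                            # (1, 1, 3, 3)
    # absolute[0, 0] ==
    # [[2., 3., 4.],
    #  [1., 2., 3.],
    #  [0., 1., 2.]]   i.e. entry (i, j) = j - i + (T - 1)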
+ class FFN(nn.Module):
+     """
+     2nd module of the TransformerEncoder, which is the 2nd module of the Encoder.
+     """
+     def __init__(self):
+         super().__init__()
+         self.conv1 = nn.Conv1d(192, 768, kernel_size=3, padding=1)  # (B, 192, T) -> (B, 768, T)
+         self.relu = nn.ReLU()
+         self.conv2 = nn.Conv1d(768, 192, kernel_size=3, padding=1)  # (B, 768, T) -> (B, 192, T)
+         self.dropout = nn.Dropout(0.1)
+
+     def forward(self, x, x_mask):
+         """
+         =====inputs=====
+         x: (B, 192, T)
+         x_mask: (B, 1, T)
+         =====outputs=====
+         output: (B, 192, T)
+         """
+         x = self.conv1(x)
+         x = self.relu(x)
+         x = self.dropout(x)
+         x = self.conv2(x)
+         output = x * x_mask
+         return output
+
+ class TransformerEncoder(nn.Module):
+     """
+     2nd module of the Encoder.
+     """
+     def __init__(self):
+         super().__init__()
+         self.attentions = nn.ModuleList()
+         self.norms1 = nn.ModuleList()
+         self.ffns = nn.ModuleList()
+         self.norms2 = nn.ModuleList()
+         for i in range(6):
+             self.attentions.append(MultiHeadAttention())
+             self.norms1.append(LayerNorm(192))
+             self.ffns.append(FFN())
+             self.norms2.append(LayerNorm(192))
+         self.dropout = nn.Dropout(0.1)
+
+     def forward(self, x, x_mask):
+         """
+         =====inputs=====
+         x: (B, 192, T)
+         x_mask: (B, 1, T)
+         =====outputs=====
+         output: (B, 192, T)
+         """
+         attention_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(3)
+         # (B, 1, 1, T) * (B, 1, T, 1) = (B, 1, T, T), consisting only of 0s and 1s
+         for i in range(6):
+             x = x * x_mask
+             y = self.attentions[i](x, x, attention_mask)
+             y = self.dropout(y)
+             x = x + y  # residual connection
+             x = self.norms1[i](x)  # keeps (B, 192, T)
+
+             y = self.ffns[i](x, x_mask)
+             y = self.dropout(y)
+             x = x + y  # residual connection
+             x = self.norms2[i](x)
+         output = x * x_mask
+         return output  # (B, 192, T)
+
+ class DurationPredictor(nn.Module):
+     """
+     3rd module of the Encoder.
+     """
+     def __init__(self):
+         super().__init__()
+         self.conv1 = nn.Conv1d(192, 256, kernel_size=3, padding=1)  # (B, 192, T) -> (B, 256, T)
+         self.norm1 = LayerNorm(256)
+         self.conv2 = nn.Conv1d(256, 256, kernel_size=3, padding=1)  # (B, 256, T) -> (B, 256, T)
+         self.norm2 = LayerNorm(256)
+         self.linear = nn.Conv1d(256, 1, kernel_size=1)  # (B, 256, T) -> (B, 1, T)
+
+         self.relu = nn.ReLU()
+         self.dropout = nn.Dropout(0.1)
+
+     def forward(self, x, x_mask):
+         """
+         =====inputs=====
+         x: (B, 192, T)
+         x_mask: (B, 1, T)
+         =====outputs=====
+         output: (B, 1, T)
+         """
+         x = self.conv1(x * x_mask)  # (B, 192, T) -> (B, 256, T)
+         x = self.relu(x)
+         x = self.norm1(x)
+         x = self.dropout(x)
+
+         x = self.conv2(x * x_mask)  # (B, 256, T) -> (B, 256, T)
+         x = self.relu(x)
+         x = self.norm2(x)
+         x = self.dropout(x)
+
+         x = self.linear(x * x_mask)  # (B, 256, T) -> (B, 1, T)
+         output = x * x_mask
+         return output