marigold334 committed
Commit 41989ff • 1 Parent(s): 397c86d

Upload 10 files

- Hmodel.py +289 -0
- app.py +135 -0
- commons.py +214 -0
- datautils.py +67 -0
- log/1038_eunsik_01/Glow_TTS_00289602.pt +3 -0
- log/1038_eunsik_01/HiFI_GAN_00257000.pt +3 -0
- model.py +154 -0
- module.py +299 -0
- requirements.txt +7 -0
Hmodel.py
ADDED
@@ -0,0 +1,289 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, spectral_norm, remove_weight_norm

# Based on the V2 model.
class ResBlock(nn.Module):
    def __init__(self, channels, kernel_size):
        """
        channels: number of input/output channels
        kernel_size: one of 3, 7, 11
        """
        super(ResBlock, self).__init__()
        # padding = (kernel_size-1)*dilation//2 ("same")
        self.convs1 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1,
                                  padding=(kernel_size-1)*1//2)),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1,
                                  padding=(kernel_size-1)*1//2))
        ])
        self.convs2 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=3,
                                  padding=(kernel_size-1)*3//2)),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1,
                                  padding=(kernel_size-1)*1//2))
        ])
        self.convs3 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=5,
                                  padding=(kernel_size-1)*5//2)),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1,
                                  padding=(kernel_size-1)*1//2))
        ])
        # renamed from self.modules, which shadowed nn.Module.modules()
        self.conv_groups = [self.convs1, self.convs2, self.convs3]

        # initialize weights from a normal distribution with mean 0, std 0.01
        for module in self.conv_groups:
            for conv in module:
                nn.init.normal_(conv.weight, mean=0.0, std=0.01)

    def forward(self, x):
        """
        =====inputs=====
        x: (B, channels, F)  # input features derived from the mel-spectrogram
        =====outputs=====
        x: (B, channels, F)  # output features
        """
        for module in self.conv_groups:
            for conv in module:
                y = F.leaky_relu(x, 0.1)
                y = conv(y)
                x = x + y
        return x

    def remove_weight_norm(self):
        for module in self.conv_groups:
            for conv in module:
                remove_weight_norm(conv)

class MRF(nn.Module):
    def __init__(self, channels):
        """
        channels: number of input/output channels
        """
        super(MRF, self).__init__()
        self.res_blocks = nn.ModuleList([
            ResBlock(channels, kernel_size=3),
            ResBlock(channels, kernel_size=7),
            ResBlock(channels, kernel_size=11),
        ])

    def forward(self, x):
        """
        =====inputs=====
        x: (B, channels, F)
        =====outputs=====
        x: (B, channels, F)
        """
        skip_list = []
        for res_block in self.res_blocks:
            skip_x = res_block(x)
            skip_list.append(skip_x)
        x = sum(skip_list) / len(self.res_blocks)
        return x

    def remove_weight_norm(self):
        for block in self.res_blocks:
            block.remove_weight_norm()

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.pre_conv = weight_norm(nn.Conv1d(80, 128, kernel_size=7, stride=1, dilation=1,
                                              padding=(7-1)//2))  # (B, 80, F) -> (B, 128, F)
        nn.init.normal_(self.pre_conv.weight, mean=0.0, std=0.01)  # not in the paper authors' implementation

        self.up_convs = nn.ModuleList()
        self.mrfs = nn.ModuleList()
        ku = [16, 16, 4, 4]
        for i in range(4):
            # upsample by a factor of ku[i]//2
            channels = 128//(2**(i+1))
            up_conv = weight_norm(nn.ConvTranspose1d(128//(2**i), channels, kernel_size=ku[i], stride=ku[i]//2,
                                                     padding=(ku[i]-ku[i]//2)//2))
            # (B, 128, F) -(1)-> (B, 64, F*8) -(2)-> (B, 32, F*8*8) -(3)-> (B, 16, F*8*8*2) -(4)-> (B, 8, F*8*8*2*2)
            nn.init.normal_(up_conv.weight, mean=0.0, std=0.01)
            self.up_convs.append(up_conv)

            # MRF
            mrf = MRF(channels)  # (B, channels, F) -> (B, channels, F)
            self.mrfs.append(mrf)

        self.post_conv = weight_norm(nn.Conv1d(8, 1, kernel_size=7, stride=1, dilation=1,
                                               padding=(7-1)//2))  # (B, 8, F*256) -> (B, 1, F*256)
        nn.init.normal_(self.post_conv.weight, mean=0.0, std=0.01)

    def forward(self, x):
        """
        =====inputs=====
        x: (B, 80, F)  # mel-spectrogram
        =====outputs=====
        x: (B, 1, F*256)  # waveform
        """
        x = self.pre_conv(x)  # (B, 80, F) -> (B, 128, F)
        for i in range(4):
            x = F.leaky_relu(x, 0.1)
            x = self.up_convs[i](x)
            x = self.mrfs[i](x)
            # final: (B, 128, F) -> (B, 8, F*256)
        x = F.leaky_relu(x, 0.1)
        x = self.post_conv(x)  # (B, 8, F*256) -> (B, 1, F*256)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.up_convs:
            remove_weight_norm(l)
        for l in self.mrfs:
            l.remove_weight_norm()
        remove_weight_norm(self.pre_conv)
        remove_weight_norm(self.post_conv)

class SubPD(nn.Module):
    def __init__(self, period):
        # period: one of 2, 3, 5, 7, 11
        super(SubPD, self).__init__()
        self.period = period

        self.convs = nn.ModuleList()
        channels = 1
        for i in range(1, 5):  # implemented as in the paper, rather than the authors' modified implementation
            conv = weight_norm(nn.Conv2d(channels, 2**(5+i), kernel_size=(5, 1), stride=(3, 1), dilation=1, padding=0))
            self.convs.append(conv)
            channels = 2**(5+i)
        # (B, 1, [T/p]+1, p) -(1)-> (B, 64, ?, p) -(2)-> (B, 128, ?, p) -(3)-> (B, 256, ?, p) -(4)-> (B, 512, ?, p)
        last_conv = weight_norm(nn.Conv2d(channels, 1024, kernel_size=(5, 1), stride=(1, 1), dilation=1,
                                          padding=(2, 0)))  # (B, 512, ?, p) -> (B, 1024, ?, p)
        self.convs.append(last_conv)

        self.post_conv = weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1), stride=(1, 1), dilation=1,
                                               padding=(1, 0)))  # (B, 1024, ?, p) -> (B, 1, ?, p)

    def forward(self, waveform):
        """
        =====inputs=====
        waveform: (B, 1, T)
        =====outputs=====
        x: (B, ?)  # flattened real/fake score vector (raw scores; no sigmoid)
        features: list of all intermediate features (used for the Feature Matching Loss)
        """
        features = []

        B, _, T = waveform.size()
        P = self.period
        # padding
        if T % P != 0:
            padding = P - (T % P)
            waveform = F.pad(waveform, (0, padding), "reflect")  # pad 0 on the left and `padding` on the right; reflect mirrors the signal
            # e.g. padding [1, 2, 3, 4, 5] by 2 on the left and 3 on the right in reflect mode -> [3, 2, 1, 2, 3, 4, 5, 4, 3, 2]
            T += padding
        # reshape
        x = waveform.view(B, 1, T//P, P)  # (B, 1, [T/P]+1, P)

        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, 0.1)
            features.append(x)
        x = self.post_conv(x)
        features.append(x)
        x = torch.flatten(x, 1, -1)  # flatten from dim 1 to the last dim | (B, ?)
        # no sigmoid or clipping: the LSGAN losses in commons.py operate on raw scores
        return x, features

class MPD(nn.Module):
    def __init__(self):
        super(MPD, self).__init__()
        self.sub_pds = nn.ModuleList([
            SubPD(2), SubPD(3), SubPD(5), SubPD(7), SubPD(11),
        ])  # (B, 1, T) -> (B, ?), features list

    def forward(self, real_waveform, gen_waveform):
        """
        =====inputs=====
        real_waveform: (B, 1, T)  # real speech
        gen_waveform: (B, 1, T)  # generated speech
        =====outputs=====
        real_outputs: list of (B, ?) (len=5)  # SubPD outputs for real speech
        gen_outputs: list of (B, ?)  # SubPD outputs for generated speech
        real_features: list of feature lists  # SubPD features for real speech
        gen_features: list of feature lists  # SubPD features for generated speech
        """
        real_outputs, gen_outputs, real_features, gen_features = [], [], [], []
        for sub_pd in self.sub_pds:
            real_output, real_feature = sub_pd(real_waveform)
            gen_output, gen_feature = sub_pd(gen_waveform)
            real_outputs.append(real_output)
            gen_outputs.append(gen_output)
            real_features.append(real_feature)
            gen_features.append(gen_feature)
        return real_outputs, gen_outputs, real_features, gen_features

class SubSD(nn.Module):
    def __init__(self, first=False):
        """
        first: boolean (if True, apply spectral normalization)
        """
        super(SubSD, self).__init__()
        norm = spectral_norm if first else weight_norm  # spectral_norm if first is True, otherwise weight_norm
        self.convs = nn.ModuleList([  # implemented to match the MelGAN paper
            norm(nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=(15-1)//2)),  # (B, 1, T) -> (B, 16, T)
            norm(nn.Conv1d(16, 64, kernel_size=41, stride=4, groups=4, padding=(41-1)//2)),  # (B, 16, T) -> (B, 64, ~T/4)
            norm(nn.Conv1d(64, 256, kernel_size=41, stride=4, groups=16, padding=(41-1)//2)),  # (B, 64, ~T/4) -> (B, 256, ~T/16)
            norm(nn.Conv1d(256, 1024, kernel_size=41, stride=4, groups=64, padding=(41-1)//2)),  # (B, 256, ~T/16) -> (B, 1024, ~T/64)
            norm(nn.Conv1d(1024, 1024, kernel_size=41, stride=4, groups=256, padding=(41-1)//2)),  # (B, 1024, ~T/64) -> (B, 1024, ~T/256)
            norm(nn.Conv1d(1024, 1024, kernel_size=5, stride=1, padding=(5-1)//2))  # (B, 1024, ~T/256) -> (B, 1024, ~T/256)
        ])
        self.post_conv = norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=(3-1)//2))  # (B, 1024, ?) -> (B, 1, ?)

    def forward(self, waveform):
        """
        =====inputs=====
        waveform: (B, 1, T)
        =====outputs=====
        x: (B, ?)  # flattened real/fake score vector (raw scores; no sigmoid)
        features: list of all intermediate features (used for the Feature Matching Loss)
        """
        features = []
        x = waveform
        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, 0.1)
            features.append(x)
        x = self.post_conv(x)  # (B, 1, ?)
        features.append(x)
        x = x.squeeze(1)  # (B, ?)
        # as in SubPD, raw scores are returned; the LSGAN losses need no sigmoid or clipping
        return x, features

class MSD(nn.Module):
    def __init__(self):
        super(MSD, self).__init__()
        self.sub_sds = nn.ModuleList([
            SubSD(first=True), SubSD(), SubSD()
        ])  # (B, 1, T) -> (B, ?), features list
        self.avgpool = nn.AvgPool1d(kernel_size=4, stride=2, padding=2)  # x2 downsampling

    def forward(self, real_waveform, gen_waveform):
        """
        =====inputs=====
        real_waveform: (B, 1, T)  # real speech
        gen_waveform: (B, 1, T)  # generated speech
        =====outputs=====
        real_outputs: list of (B, ?) (len=3)  # SubSD outputs for real speech
        gen_outputs: list of (B, ?)  # SubSD outputs for generated speech
        real_features: list of feature lists  # SubSD features for real speech
        gen_features: list of feature lists  # SubSD features for generated speech
        """
        real_outputs, gen_outputs, real_features, gen_features = [], [], [], []
        for idx, sub_sd in enumerate(self.sub_sds):
            if idx != 0:
                real_waveform = self.avgpool(real_waveform)
                gen_waveform = self.avgpool(gen_waveform)
            real_output, real_feature = sub_sd(real_waveform)
            gen_output, gen_feature = sub_sd(gen_waveform)
            real_outputs.append(real_output)
            gen_outputs.append(gen_output)
            real_features.append(real_feature)
            gen_features.append(gen_feature)
        return real_outputs, gen_outputs, real_features, gen_features
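A minimal vocoding sketch for the Generator above (hypothetical usage, not part of the commit): it pushes a dummy 80-band mel-spectrogram through the model and checks the expected 256x upsampling (8*8*2*2 from the four transposed convolutions).

import torch
from Hmodel import Generator

vocoder = Generator()
vocoder.eval()
vocoder.remove_weight_norm()   # fuse weight norm for inference
mel = torch.randn(1, 80, 100)  # dummy mel: (B, 80, F)
with torch.no_grad():
    wav = vocoder(mel)         # (B, 1, F*256)
print(wav.shape)               # torch.Size([1, 1, 25600])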
app.py
ADDED
@@ -0,0 +1,135 @@
import streamlit as st
import soundfile as sf
import timeit
import uuid
import re

import os

import torch

from datautils import *
from model import Generator as Glow_model
from utils import scan_checkpoint, plot_mel, plot_alignment  # note: utils.py is not among the files in this commit
from Hmodel import Generator as GAN_model
from scipy.io.wavfile import write  # needed by TTS.inference; missing from the original imports

MAX_WAV_VALUE = 32768.0
device = torch.device('cuda:0')
torch.cuda.manual_seed(1234)
name = '1038_eunsik_01'
out_dir = 'cache_sound'  # assumed output directory; the original used out_dir without defining it

# Nix
from nix.models.TTS import NixTTSInference

def init_session_state():
    # Model
    if "init_model" not in st.session_state:
        st.session_state.init_model = True
        st.session_state.model_variant = "KSS"
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")

def update_model():
    if st.session_state.model_variant == "KSS":
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-v0.1")
    elif st.session_state.model_variant == "은식":
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")

def update_session_state(state_id, state_value):
    st.session_state[f"{state_id}"] = state_value

def centered_text(input_text, mode="h1"):
    st.markdown(
        f"<{mode} style='text-align: center;'>{input_text}</{mode}>", unsafe_allow_html=True)

def generate_voice(input_text):
    # TTS Inference
    c, c_length, phoneme = st.session_state.TTS.tokenize(input_text)
    voice = st.session_state.TTS.vocalize(c, c_length)

    # Save audio (bug in Streamlit, can't play numpy arrays directly)
    sf.write(f"cache_sound/{input_text}.wav", voice[0, 0], 22050)

    # Play audio
    st.audio(f"cache_sound/{input_text}.wav", format="audio/wav")
    os.remove(f"cache_sound/{input_text}.wav")
    st.caption("Generated Voice")

st.set_page_config(
    page_title="소신 Team Demo",
    page_icon="🔉",
)

init_session_state()

centered_text("🔉 소신 Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")

mode = "p"
st.markdown(
    f"<{mode} style='text-align: left;'><small>This is a demo trained on our own voice. The \"KSS\" voice was trained 3 times; \"은식\" was fine-tuned from \"KSS\" 3 times. We adapted this demo format from the Nix-TTS Interactive Demo.</small></{mode}>",
    unsafe_allow_html = True
)

st.write(" ")
st.write(" ")
col1, col2 = st.columns(2)

with col1:
    input_text = st.text_input(
        "한글로만 입력해주세요",  # "Please enter Hangul only"
        value = "딥러닝은 정말 재밌어!",  # "Deep learning is really fun!"
    )
with col2:
    model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "은식"], index = 1)  # "Please choose a voice"
    if model_variant != st.session_state.model_variant:
        # Update variant choice
        update_session_state("model_variant", model_variant)
        # Re-load model
        update_model()

button_gen = st.button("Generate Voice")
if button_gen:
    generate_voice(input_text)


class TTS:
    def __init__(self, model_variant):
        # models are moved to `device` so the weights match the inference inputs below
        self.flowgenerator = Glow_model(n_vocab = 70, h_c = 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads = 2, layers_enc = 6).to(device)
        self.voicegenerator = GAN_model().to(device)
        if model_variant == '은식':
            last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
            check_point = torch.load(last_chpt1)
            self.flowgenerator.load_state_dict(check_point['generator'])
            self.flowgenerator.decoder.skip()
            self.flowgenerator.eval()
        if model_variant == '은식':
            last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
            check_point = torch.load(last_chpt2)
            self.voicegenerator.load_state_dict(check_point['gen_model'])
            self.voicegenerator.eval()
            self.voicegenerator.remove_weight_norm()

    def inference(self, input_text):
        # strip punctuation, then convert the text to a symbol-id sequence
        # (the original referenced `sentence` and `text` before they were assigned)
        filters = '([.,!?])'
        sentence = re.sub(re.compile(filters), '', input_text)
        x = text_to_sequence(sentence)
        x = torch.tensor(x).unsqueeze(0).to(device).long()
        x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)

        with torch.no_grad():
            noise_scale = .667
            length_scale = 1.0
            (y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
            y = self.voicegenerator(y_gen_tst)
            audio = y.squeeze() * MAX_WAV_VALUE
            audio = audio.cpu().numpy().astype('int16')

        output_file = os.path.join(out_dir, 'gen_' + sentence[:3] + '.wav')
        write(output_file, 22050, audio)
        print(f'{sentence} is stored in {out_dir}')
        # optional debug visualization; fig1/fig2 and ipd were never defined in this
        # file, so the originally unreachable calls are kept as comments:
        # plot_mel(y_gen_tst[0].data.cpu().numpy())
        # plot_alignment(attn_gen[0, 0].data.cpu().numpy(), sequence_to_text(x[0].data.cpu().numpy()))
        # ipd.display(fig1, fig2)
        # ipd.Audio(filename=output_file)
        return audio
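A hypothetical driver for the TTS class above, assuming a CUDA device and the two checkpoints under ./log/1038_eunsik_01/ shipped in this commit:

tts = TTS('은식')                               # Glow-TTS mel generator + HiFi-GAN vocoder
audio = tts.inference('딥러닝은 정말 재밌어!')  # int16 waveform at 22050 Hz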
commons.py
ADDED
@@ -0,0 +1,214 @@
import torch
import torch.nn.functional as F
from module import *
import numpy as np
import math


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = max(length)
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def maximum_path(value, mask, max_neg_val=-np.inf):
    """ Numpy-friendly version. It's about 4 times faster than the torch version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask

    device = value.device
    dtype = value.dtype
    value = value.cpu().detach().numpy()
    mask = mask.cpu().detach().numpy().astype(bool)  # np.bool is deprecated; plain bool is equivalent

    b, t_x, t_y = value.shape
    direction = np.zeros(value.shape, dtype=np.int64)
    v = np.zeros((b, t_x), dtype=np.float32)
    x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)

    for j in range(t_y):
        v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
        v1 = v
        max_mask = (v1 >= v0)
        v_max = np.where(max_mask, v1, v0)
        direction[:, :, j] = max_mask

        index_mask = (x_range <= j)
        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
    direction = np.where(mask, direction, 1)

    path = np.zeros(value.shape, dtype=np.float32)
    index = mask[:, :, 0].sum(1).astype(np.int64) - 1
    index_range = np.arange(b)

    for j in reversed(range(t_y)):
        path[index_range, index, j] = 1
        index = index + direction[index_range, index, j] - 1

    path = path * mask.astype(np.float32)
    path = torch.from_numpy(path).to(device=device, dtype=dtype)
    return path


def generate_path(duration, mask):
    """
    duration: [b, t_x]
    mask: [b, t_x, t_y]
    """
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]
    path = path * mask
    return path

def mle_loss(z, m, logs, logdet, mask):
    # negative normal log-likelihood without the constant term
    l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m)**2))
    l = l - torch.sum(logdet)  # log jacobian determinant
    # average across batch, channel and time axes
    l = l / torch.sum(torch.ones_like(z) * mask)
    l = l + 0.5 * math.log(2 * math.pi)  # add the remaining constant term
    return l


def duration_loss(logw, logw_, lengths):
    l = torch.sum((logw - logw_)**2) / torch.sum(lengths)
    return l


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

def GAN_Loss_Generator(gen_outputs):
    """
    gen_outputs: list of (B, ?)  # outputs of MPD (len=5) or MSD (len=3)
    """
    loss = 0
    for DG in gen_outputs:
        loss += torch.mean((DG-1)**2)
    return loss

def GAN_Loss_Discriminator(real_outputs, gen_outputs):
    """
    real_outputs: list of (B, ?)  # outputs of MPD (len=5) or MSD (len=3)
    gen_outputs: list of (B, ?)  # outputs of MPD (len=5) or MSD (len=3)
    """
    loss = 0
    for D, DG in zip(real_outputs, gen_outputs):
        loss += torch.mean((D-1)**2 + DG**2)
    return loss

def Mel_Spectrogram_Loss(real_mel, gen_mel):
    """
    real_mel: (B, F, 80)  # mel-spectrogram taken from the dataloader
    gen_mel: (B, F, 80)  # mel-spectrogram of the waveform produced by the Generator
    """
    loss = F.l1_loss(real_mel, gen_mel)
    return 45*loss

def Feature_Matching_Loss(real_features, gen_features):
    """
    real_features: list of lists  # outputs of MPD (len=[5, 6]) or MSD (len=[3, 7])
    gen_features: list of lists  # outputs of MPD (len=[5, 6]) or MSD (len=[3, 7])
    """
    loss = 0
    for Ds, DGs in zip(real_features, gen_features):
        for D, DG in zip(Ds, DGs):
            loss += torch.mean(torch.abs(D - DG))
    return 2*loss

def Final_Loss_Generator(mpd_gen_outputs, mpd_real_features, mpd_gen_features,
                         msd_gen_outputs, msd_real_features, msd_gen_features,
                         real_mel, gen_mel):
    """
    =====inputs=====
    args [0:3]: the last three MPD outputs (gen_outputs, real_features, gen_features)
    args [3:6]: the last three MSD outputs
    args [6:8]: real_mel and gen_mel
    =====outputs=====
    Generator_Loss
    Mel_Loss
    """
    Gen_Adv1 = GAN_Loss_Generator(mpd_gen_outputs)
    Gen_Adv2 = GAN_Loss_Generator(msd_gen_outputs)
    Adv = Gen_Adv1 + Gen_Adv2
    FM1 = Feature_Matching_Loss(mpd_real_features, mpd_gen_features)
    FM2 = Feature_Matching_Loss(msd_real_features, msd_gen_features)
    FM = FM1 + FM2
    Mel_Loss = Mel_Spectrogram_Loss(real_mel, gen_mel)
    Generator_Loss = Adv + FM + Mel_Loss

    return Generator_Loss, Mel_Loss, Adv, FM

def Final_Loss_Discriminator(mpd_real_outputs, mpd_gen_outputs,
                             msd_real_outputs, msd_gen_outputs):
    """
    =====inputs=====
    args [0:2]: the first two MPD outputs (real_outputs, gen_outputs)
    args [2:4]: the first two MSD outputs
    =====outputs=====
    Discriminator_Loss
    """
    Disc_Adv1 = GAN_Loss_Discriminator(mpd_real_outputs, mpd_gen_outputs)
    Disc_Adv2 = GAN_Loss_Discriminator(msd_real_outputs, msd_gen_outputs)
    Discriminator_Loss = Disc_Adv1 + Disc_Adv2

    return Discriminator_Loss

class Adam():
    def __init__(self, params, scheduler, dim_model, warmup_steps=4000, lr=1e0, betas=(0.9, 0.98), eps=1e-9):
        self.params = params
        self.scheduler = scheduler
        self.dim_model = dim_model
        self.warmup_steps = warmup_steps
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.step_num = 1
        self.cur_lr = lr * self._get_lr_scale()

        self._optim = torch.optim.Adam(params, lr=self.cur_lr, betas=betas, eps=eps)
        self.param_groups = self._optim.param_groups

    def _get_lr_scale(self):
        if self.scheduler == "noam":
            return np.power(self.dim_model, -0.5) * np.min([np.power(self.step_num, -0.5), self.step_num * np.power(self.warmup_steps, -1.5)])
        else:
            return 1

    def _update_learning_rate(self):
        self.step_num += 1
        if self.scheduler == "noam":
            self.cur_lr = self.lr * self._get_lr_scale()
            for param_group in self._optim.param_groups:
                param_group['lr'] = self.cur_lr

    def get_lr(self):
        return self.cur_lr

    def step(self):
        self._optim.step()
        self._update_learning_rate()

    def zero_grad(self):
        self._optim.zero_grad()

    def load_state_dict(self, d):
        self._optim.load_state_dict(d)

    def state_dict(self):
        return self._optim.state_dict()
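For context, a sketch of one HiFi-GAN-style loss computation wiring the MPD/MSD discriminators from Hmodel.py into the helpers above; the tensors and shapes are dummies, so treat this as an illustration rather than the project's actual training loop.

import torch
from Hmodel import Generator, MPD, MSD
from commons import Final_Loss_Generator, Final_Loss_Discriminator
from datautils import mel_spectrogram

gen, mpd, msd = Generator(), MPD(), MSD()
real_wav = torch.rand(2, 1, 8192) * 2 - 1        # dummy waveform in [-1, 1]
real_mel = mel_spectrogram(real_wav.squeeze(1))  # (B, 80, F)
gen_wav = gen(real_mel)                          # (B, 1, F*256)
gen_mel = mel_spectrogram(gen_wav.squeeze(1))

# discriminator loss (generated audio detached)
mpd_out = mpd(real_wav, gen_wav.detach())
msd_out = msd(real_wav, gen_wav.detach())
d_loss = Final_Loss_Discriminator(mpd_out[0], mpd_out[1], msd_out[0], msd_out[1])

# generator loss = adversarial + feature matching + mel L1
_, mpd_g, mpd_rf, mpd_gf = mpd(real_wav, gen_wav)
_, msd_g, msd_rf, msd_gf = msd(real_wav, gen_wav)
g_loss, mel_loss, adv, fm = Final_Loss_Generator(mpd_g, mpd_rf, mpd_gf,
                                                 msd_g, msd_rf, msd_gf,
                                                 real_mel, gen_mel)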
datautils.py
ADDED
@@ -0,0 +1,67 @@
from jamo import hangul_to_jamo
import librosa
import torch

sample_rate = 22050
preemphasis = 0.97
n_fft = 1024
hop_length = 256
win_length = 1024
ref_db = 20
max_db = 100
mel_dim = 80

PAD = '_'
EOS = '~'
SPACE = ' '

JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)])
JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)])
JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)])

VALID_CHARS = JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + SPACE
symbols = PAD + EOS + VALID_CHARS

_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

# Decompose the text into initial/medial/final jamo and return the corresponding symbol ids.
def text_to_sequence(text):
    sequence = []
    if not 0x1100 <= ord(text[0]) <= 0x1113:
        text = ''.join(list(hangul_to_jamo(text)))
    for s in text:
        sequence.append(_symbol_to_id[s])
    sequence.append(_symbol_to_id['~'])
    return sequence

def sequence_to_text(sequence):
    result = ''
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            result += s
    return result.replace('}{', ' ')

def mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=22050, hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False):
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))
    """

    mel = librosa.filters.mel(sampling_rate, n_fft, num_mels, fmin, fmax)  # positional args match librosa 0.7.2 (pinned in requirements.txt)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=torch.hann_window(win_size).to(y.device),
                      center=center, pad_mode='reflect', normalized=False, onesided=True)

    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(torch.from_numpy(mel).float().to(y.device), spec)
    spec = torch.log(torch.clamp(spec, min=1e-5) * 1)

    return spec
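A quick round-trip check of the jamo tokenizer above (hypothetical usage):

from datautils import text_to_sequence, sequence_to_text

seq = text_to_sequence('안녕')  # Hangul -> jamo -> symbol ids, with EOS ('~') appended
print(seq)
print(sequence_to_text(seq))    # decomposed jamo string ending in '~'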
log/1038_eunsik_01/Glow_TTS_00289602.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7c84ba863d822d97db5c00dc4e69fcdf15d9040c456cab6503179bb220748cee
size 279930587
log/1038_eunsik_01/HiFI_GAN_00257000.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2db9dd4f1fa06c40d98a19ac4e1740e390d66de5cef3138d670a29d8f917bddb
size 421187547
model.py
ADDED
@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from module import *
from commons import *
import math


class Generator(nn.Module):
    def __init__(self, n_vocab, h_c, f_c, f_c_dp, out_c, k_s = 3, k_s_dec = 5, heads=2, layers_enc = 6):
        super().__init__()
        self.encoder = Encoder(n_vocab, out_c, h_c, f_c, f_c_dp, heads = heads, layers = layers_enc, k_s = k_s)
        self.decoder = Decoder(in_c = out_c, hi_c = h_c, k_s = k_s_dec)

    def forward(self, x, x_lengths, y=None, y_lengths=None, gen = False, noise_scale=1., length_scale=1.):
        x_m, x_logs, logw, x_mask = self.encoder(x, x_lengths)
        if gen:
            w = torch.exp(logw) * x_mask * length_scale
            w_ceil = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
            y_max_length = None
            y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)

            z_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
            attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)
            attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

            z_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            z_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask

            z = (z_m + torch.exp(z_logs) * torch.randn_like(z_m) * noise_scale) * z_mask
            y, logdet = self.decoder(z, z_mask, reverse=True)
            return (y, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)

        else:
            y_max_length = y.size(2)
            y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
            z_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
            attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)

            z, logdet = self.decoder(y, z_mask, reverse=False)
            with torch.no_grad():
                x_s_sq_r = torch.exp(-2 * x_logs)
                logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(-1)  # [b, t, 1]
                logp2 = torch.matmul(x_s_sq_r.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp3 = torch.matmul((x_m * x_s_sq_r).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp4 = torch.sum(-0.5 * (x_m ** 2) * x_s_sq_r, [1]).unsqueeze(-1)  # [b, t, 1]
                logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']

                attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
            z_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            z_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask

            return (z, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)

    def preprocess(self, y, y_lengths, y_max_length):
        if y_max_length is not None:
            y_max_length = (y_max_length // 2) * 2
            y = y[:, :, :y_max_length]
        y_lengths = (y_lengths // 2) * 2
        return y, y_lengths, y_max_length


class Encoder(nn.Module):
    def __init__(self, n_vocab, out_c, h_c, f_c, f_c_dp, heads, layers, k_s, p=0.1, mean_only = True):
        super().__init__()
        self.h_c = h_c
        self.mean_only = mean_only
        self.emb = nn.Embedding(n_vocab, h_c)
        nn.init.normal_(self.emb.weight, 0.0, h_c**(-0.5))
        self.prenet = Prenet(in_c = h_c, hi_c = h_c, out_c = h_c, k_s = 5)
        self.drop = nn.Dropout(p=p)
        self.atten_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        for i in range(layers):
            self.atten_layers.append(MultiheadAttention(h_c, h_c, heads, window_size=4, heads_share=True, p=0.1, block_length=None))
            self.norm_layers.extend([Layernorm(h_c), Layernorm(h_c)])
            self.ffn_layers.append(FFN(h_c, f_c, k_s, p))
        self.proj_m = nn.Conv1d(h_c, out_c, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(h_c, out_c, 1)
        self.proj_w = DurationPredictor(h_c, f_c_dp, k_s, p=p)

    def forward(self, x, x_length):
        x = self.emb(x) * math.sqrt(self.h_c)  # [b,t,h]  (torch.sqrt on an integer tensor would fail)
        x = torch.transpose(x, 1, -1)  # [b,h,t]
        x_mask = torch.unsqueeze(sequence_mask(x_length, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        atten_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(len(self.atten_layers)):
            x = x * x_mask
            y = self.drop(self.atten_layers[i](x, atten_mask))
            x = self.norm_layers[2*i](x+y)

            y = self.drop(self.ffn_layers[i](x, x_mask))
            x = self.norm_layers[2*i+1](x+y)
        x = x*x_mask

        x_m = self.proj_m(x)
        if not self.mean_only:
            x_logs = self.proj_s(x)  # fixed: the original called self.proj_m here
        else:
            x_logs = torch.zeros_like(x_m)
        logw = self.proj_w(x.detach(), x_mask)

        return x_m, x_logs, logw, x_mask

class Decoder(nn.Module):
    def __init__(self, in_c, hi_c, k_s, d_l = 1, blocks = 12, splits = 4):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(blocks):
            self.flows.extend([ActNorm(in_c*2), InvConvNear(splits = splits), Couplinglayer(in_c*2, hi_c, k_s, d_l = d_l)])

    def forward(self, x, x_mask = None, reverse = False):
        if not reverse:
            flows = self.flows
            tot_logdet = 0
        else:
            flows = reversed(self.flows)
            tot_logdet = None

        b, c, t = x.shape
        t = t - t % 2
        if x_mask is None:
            mask = torch.ones(b, 1, t//2, device=x.device, dtype=x.dtype)
        else:
            mask = x_mask[:, :, 1::2]
        x = x[:, :, :t].reshape(b, c, t//2, 2).transpose(2, 3).contiguous().reshape(b, 2*c, t//2) * mask  # [b, 2c, t/2]
        for f in flows:
            x, logdet = f(x, mask, reverse = reverse)
            if not reverse:
                tot_logdet = tot_logdet + logdet
        if x_mask is None:
            mask = torch.ones(b, 1, t, device=x.device, dtype=x.dtype)
        else:
            mask = x_mask[:, :, :t]
        x = x.reshape(b, c, 2, t//2).transpose(2, 3).contiguous().reshape(b, c, t) * mask  # [b, c, t]
        return x, tot_logdet

    def skip(self):
        for f in self.flows:
            f.skip()

    def ddi_init(self):
        for i, f in enumerate(self.flows):
            if i % 3 == 0:
                f.set_ddi()
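A minimal, untested training-step sketch for the Glow-TTS Generator above, pairing its forward pass (gen=False) with the MLE and duration losses from commons.py; every tensor here is a dummy with assumed shapes.

import torch
from model import Generator
from commons import mle_loss, duration_loss

model = Generator(n_vocab=70, h_c=192, f_c=768, f_c_dp=256, out_c=80)
x = torch.randint(2, 70, (2, 20))  # dummy token ids (B, T_text)
x_lengths = torch.tensor([20, 16])
y = torch.randn(2, 80, 120)        # dummy target mels (B, 80, T_mel)
y_lengths = torch.tensor([120, 100])

(z, z_m, z_logs, logdet, z_mask), _, (attn, logw, logw_) = \
    model(x, x_lengths, y=y, y_lengths=y_lengths, gen=False)
loss = mle_loss(z, z_m, z_logs, logdet, z_mask) + duration_loss(logw, logw_, x_lengths)
loss.backward()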
module.py
ADDED
@@ -0,0 +1,299 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

######################################### encoder ##############################################

class Layernorm(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, channels))
        self.beta = nn.Parameter(torch.zeros(1, channels))

    def forward(self, x):
        m = torch.mean(x, dim = 1, keepdim = True)
        v = torch.mean((x-m)**2, dim = 1, keepdim = True)
        x = (x - m) * torch.rsqrt(v + 1e-4)  # normalization
        n = len(x.shape)
        shape = [1, -1] + [1]*(n-2)
        x = x*self.gamma.reshape(*shape) + self.beta.reshape(*shape)
        return x

class Prenet(nn.Module):
    def __init__(self, in_c, hi_c, out_c, k_s = 5, layers = 3, p = 0.05):
        super().__init__()
        self.crn = nn.ModuleList()
        self.crn.extend([nn.Conv1d(in_c, hi_c, k_s, padding = k_s//2), Layernorm(hi_c), nn.ReLU(), nn.Dropout(p=p)])
        self.crn.extend([nn.Conv1d(hi_c, hi_c, k_s, padding = k_s//2), Layernorm(hi_c), nn.ReLU(), nn.Dropout(p=p)])
        self.crn.extend([nn.Conv1d(hi_c, hi_c, k_s, padding = k_s//2), Layernorm(hi_c), nn.ReLU(), nn.Dropout(p=p)])

        self.proj = nn.Conv1d(hi_c, out_c, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, start, x_mask=1):
        x = start
        for layer in self.crn:
            x = layer(x)  # [b, c, t]
            x = x * x_mask
        x = self.proj(x) + start  # [b, c, t]
        end = x * x_mask
        return end  # [b, c, t]

class MultiheadAttention(nn.Module):
    def __init__(self, c, out_c, heads, window_size=4, heads_share=True, p=0.1, block_length=None):
        super().__init__()

        self.k = c // heads
        self.window_size = window_size
        self.proj_q = nn.Conv1d(c, c, 1)
        self.proj_k = nn.Conv1d(c, c, 1)
        self.proj_v = nn.Conv1d(c, c, 1)

        nn.init.xavier_uniform_(self.proj_q.weight)
        nn.init.xavier_uniform_(self.proj_k.weight)
        nn.init.xavier_uniform_(self.proj_v.weight)

        n_heads_rel = 1 if heads_share else heads
        self.d_k = (self.k)**(-0.5)
        self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size*2 + 1, self.k) * self.d_k)
        self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size*2 + 1, self.k) * self.d_k)

        self.conv_o = nn.Conv1d(c, out_c, 1)
        self.drop = nn.Dropout(p=p)

    def forward(self, x, attn_mask=None):
        query, key, value = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        b, c, t = query.shape
        h, k = c // self.k, self.k

        query = query.reshape(b, h, k, t)
        key = key.reshape(b, h, k, t)
        value = value.reshape(b, h, k, t)

        matrix = self.get_relative_matrix(self.emb_rel_k, t)
        rel_logit = torch.matmul(matrix.unsqueeze(0), query)  # [1,1,2t-1,k] * [b,h,k,t] = [b,h,2t-1,t]
        abs_logit = self.rel_to_abs(rel_logit.transpose(2, 3))
        local_score = abs_logit * self.d_k

        score = torch.matmul(query.transpose(2, 3), key) * self.d_k + local_score
        if attn_mask is not None:
            score = score.masked_fill(attn_mask == 0, -1e4)

        align = F.softmax(score, dim = -1)
        atten = self.drop(align)
        self.atten = atten

        matrix = self.get_relative_matrix(self.emb_rel_v, t).transpose(1, 2)  # [1,k,2t-1]
        weight = self.abs_to_rel(atten).transpose(2, 3)  # [b,h,2t-1,t]
        output = torch.matmul(value, atten) + torch.matmul(matrix.unsqueeze(0), weight)  # [b,h,k,t]
        x = self.conv_o(output.contiguous().reshape(b, c, t))
        return x

    def get_relative_matrix(self, emb_rel_k, t):
        s = self.window_size
        pad_size = max(t - s - 1, 0)
        start = max(s + 1 - t, 0)
        emb_rel_k = F.pad(emb_rel_k, (0, 0, pad_size, pad_size))
        return emb_rel_k[:, start:start + 2*t - 1]  # fixed: the original sliced start+2*t+1, which over-slices for short sequences

    def rel_to_abs(self, x):
        b, h, t, _ = x.shape
        x = F.pad(x, (0, 1)).reshape(b, h, 2*t*t)
        x = F.pad(x, (0, t-1)).reshape(b, h, t+1, 2*t-1)[:, :, :t, t-1:]
        return x

    def abs_to_rel(self, x):
        b, h, t, t = x.shape
        x = F.pad(x, (0, t-1)).reshape(b, h, 2*t*t - t)
        x = F.pad(x, (t, 0)).reshape(b, h, t, 2*t)[:, :, :, 1:]
        return x

class FFN(nn.Module):
    def __init__(self, h_c, f_c, k_s, p = 0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(h_c, f_c, k_s, padding=k_s//2)
        self.conv2 = nn.Conv1d(f_c, h_c, k_s, padding=k_s//2)
        self.drop = nn.Dropout(p=p)

    def forward(self, x, x_mask = None):
        x = self.conv2(self.drop(F.relu(self.conv1(x*x_mask)))*x_mask)
        return x * x_mask

class DurationPredictor(nn.Module):
    def __init__(self, in_c, f_c, k_s, p=0.1):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv1d(in_c, f_c, k_s, padding=k_s//2),
                                    nn.ReLU(),
                                    Layernorm(f_c),
                                    nn.Dropout(p=p))
        self.block2 = nn.Sequential(nn.Conv1d(f_c, f_c, k_s, padding=k_s//2),
                                    nn.ReLU(),
                                    Layernorm(f_c),
                                    nn.Dropout(p=p))
        self.proj = nn.Conv1d(f_c, 1, 1)

    def forward(self, x, x_mask):
        x = self.block1(x * x_mask)
        x = self.block2(x * x_mask)
        x = self.proj(x * x_mask)
        return x * x_mask

######################################### decoder ##############################################
# TorchScript: static typing of tensors, computation-graph optimization, and
# ahead-of-time compilation accelerate this fused gate.

@torch.jit.script
def fuse_tan_sig_add(x: torch.Tensor, mid: int) -> torch.Tensor:
    a, b = x[:, :mid, :], x[:, mid:, :]
    return torch.sigmoid(a) * torch.tanh(b)

class WN(nn.Module):  # non-causal WaveNet without dilation
    def __init__(self, hi_c, k_s, d_l = 1, layers = 3, p=0.05):
        super().__init__()
        self.hi_c = hi_c
        self.resblocks = nn.ModuleList()
        self.skipblocks = nn.ModuleList()
        self.drop = nn.Dropout(p=p)
        for i in range(layers):
            res_layer = nn.Conv1d(hi_c, 2*hi_c, k_s, dilation=d_l, padding=k_s//2)
            res_layer = nn.utils.weight_norm(res_layer, name = 'weight')
            self.resblocks.append(res_layer)
            if i == layers - 1:
                skip_layer = nn.Conv1d(hi_c, hi_c, 1)  # last layer: skip output only
            else:
                skip_layer = nn.Conv1d(hi_c, 2*hi_c, 1)
            skip_layer = nn.utils.weight_norm(skip_layer, name = 'weight')
            self.skipblocks.append(skip_layer)

    def forward(self, x, x_mask = None):
        mid = self.hi_c
        end = torch.zeros_like(x, dtype=x.dtype)
        for i in range(len(self.resblocks)):
            x = self.drop(self.resblocks[i](x))  # [b, 2c, t]
            x = fuse_tan_sig_add(x, mid)  # [b, c, t]
            y = self.skipblocks[i](x)
            if i == len(self.resblocks) - 1:
                end = end + y  # last layer
            else:
                x = (x + y[:, :mid, :]) * x_mask
                end = end + y[:, mid:, :]
        return end * x_mask

    def skip(self):
        for layer1, layer2 in zip(self.resblocks, self.skipblocks):
            nn.utils.remove_weight_norm(layer1)
            nn.utils.remove_weight_norm(layer2)

class Couplinglayer(nn.Module):
    def __init__(self, in_c, hi_c, k_s, d_l = 1):
        super().__init__()
        s_proj = nn.Conv1d(in_c//2, hi_c, 1)
        self.start = nn.utils.weight_norm(s_proj, name = 'weight')
        # Initializing the last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps stabilize training (from the Glow paper).
        self.end = nn.Conv1d(hi_c, in_c, 1)
        self.end.weight.data.zero_()
        self.end.bias.data.zero_()
        self.wn = WN(hi_c, k_s, d_l)

    # y = x * exp(logs) + t
    def forward(self, x, x_mask=None, reverse = False):
        if x_mask is None:
            x_mask = 1
        mid = x.shape[1]//2  # split the channels in half
        x_0, x_1 = x[:, :mid, :], x[:, mid:, :]
        z_1 = self.end(self.wn(self.start(x_1) * x_mask, x_mask))
        logs, t = z_1[:, mid:, :], z_1[:, :mid, :]
        if reverse:
            x_0 = torch.exp(-logs) * (x_0 - t) * x_mask
            logdet = None
        else:
            x_0 = torch.exp(logs + 1e-4) * x_0 + t
            logdet = torch.sum(logs * x_mask, [1, 2])  # sum(log(s))
        z = torch.cat([x_0, x_1], dim = 1)
        return z, logdet

    def skip(self):
        self.wn.skip()

class InvConvNear(nn.Module):
    def __init__(self, splits = 4):
        super().__init__()
        self.splits = splits
        w_init = torch.linalg.qr(torch.randn((splits, splits)).normal_())[0]  # orthonormal matrix
        if torch.det(w_init) < 0:
            w_init[0, :] = -w_init[0, :]
        self.weight = nn.Parameter(w_init)

    def forward(self, x, x_mask=None, reverse = False):
        b, c, t = x.shape

        if x_mask is None:
            x_mask = 1
            x_len = torch.ones(b) * t  # [b]
        else:
            x_len = torch.sum(x_mask, [1, 2])

        s = self.splits
        x = x.reshape(b, 2, c//s, s//2, t)  # split channels into 2 groups
        x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(b, s, c//s, t)

        if reverse:
            if hasattr(self, "weight_inv"):
                weight = self.weight_inv
            else:  # fixed: the original recomputed the inverse unconditionally
                weight = torch.inverse(self.weight).to(dtype=self.weight.dtype)
            logdet = None
        else:
            weight = self.weight
            logdet = torch.logdet(weight) * (c//s) * x_len  # h*w*log(det(W)); no LU decomposition needed here

        weight = weight.unsqueeze(-1).unsqueeze(-1)
        z = F.conv2d(x, weight)  # z = matmul(weight, x_ij) for i, j in h = c//s, w = t

        z = z.reshape(b, 2, s//2, c//s, t).permute(0, 1, 3, 2, 4).contiguous().reshape(b, c, t) * x_mask
        return z, logdet

    def skip(self):
        # cache the inverse for fast reverse passes (fixed typo: was self.weigth_inv)
        self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)


class ActNorm(nn.Module):
    def __init__(self, hi_c, ddi = False):  # ddi: data-dependent initialization
        super().__init__()
        self.logs = nn.Parameter(torch.zeros(1, hi_c, 1))
        self.bias = nn.Parameter(torch.zeros(1, hi_c, 1))
        self.ddi = ddi

    def forward(self, x, x_mask = None, reverse = False):
        b, _, t = x.shape
        if x_mask is None:
            x_mask = torch.ones(b, 1, t).to(device = x.device, dtype = x.dtype)
        x_len = torch.sum(x_mask, [1, 2])
        if self.ddi:
            self.initialize(x, x_mask)
            self.ddi = False
        # y = exp(logs) * x + bias -> normalization in the channel dim
        if reverse:
            z = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = None
        else:
            z = (torch.exp(self.logs) * x + self.bias) * x_mask
            logdet = torch.sum(self.logs, [1, 2]) * x_len
        return z, logdet

    def initialize(self, x, x_mask):
        with torch.no_grad():
            n = torch.sum(x_mask, [0, 2])
            m = torch.sum(x * x_mask, [0, 2]) / n
            m_s = torch.sum(x * x * x_mask, [0, 2]) / n
            v = m_s - m**2
            logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

            init_bias = (-m * torch.exp(-logs)).reshape(*self.bias.shape).to(dtype = self.bias.dtype)  # -m/s (fixed: was -m/exp(-logs))
            init_logs = (-logs).reshape(*self.logs.shape).to(dtype = self.logs.dtype)  # -logs

            self.bias.data.copy_(init_bias)
            self.logs.data.copy_(init_logs)

    def set_ddi(self):
        self.ddi = True

    def skip(self):
        pass
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch==1.4.0
numpy==1.17.4
librosa==0.7.2
scipy==1.4.1
tensorboard==2.0
soundfile==0.10.3.post1
matplotlib==3.1.3