marigold334 committed on
Commit 41989ff
1 Parent(s): 397c86d

Upload 10 files

Hmodel.py ADDED
@@ -0,0 +1,289 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, spectral_norm, remove_weight_norm

# Based on the HiFi-GAN V2 model.
class ResBlock(nn.Module):
    def __init__(self, channels, kernel_size):
        """
        channels: number of input/output channels
        kernel_size: one of 3, 7, 11
        """
        super(ResBlock, self).__init__()
        # padding = (kernel_size-1)*dilation//2 ("same")
        self.convs1 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1,
                                  padding=(kernel_size-1)*1//2)),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1,
                                  padding=(kernel_size-1)*1//2))
        ])
        self.convs2 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=3,
                                  padding=(kernel_size-1)*3//2)),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1,
                                  padding=(kernel_size-1)*1//2))
        ])
        self.convs3 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=5,
                                  padding=(kernel_size-1)*5//2)),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, stride=1, dilation=1,
                                  padding=(kernel_size-1)*1//2))
        ])
        self.modules = [self.convs1, self.convs2, self.convs3]

        # Initialize weights from a normal distribution with mean 0 and std 0.01.
        for module in self.modules:
            for conv in module:
                nn.init.normal_(conv.weight, mean=0.0, std=0.01)

    def forward(self, x):
        """
        =====inputs=====
        x: (B, channels, F)  # input features derived from the mel-spectrogram
        =====outputs=====
        x: (B, channels, F)  # output features derived from the mel-spectrogram
        """
        for module in self.modules:
            for conv in module:
                y = F.leaky_relu(x, 0.1)
                y = conv(y)
                x = x + y
        return x

    def remove_weight_norm(self):
        for module in self.modules:
            for conv in module:
                remove_weight_norm(conv)


class MRF(nn.Module):
    def __init__(self, channels):
        """
        channels: number of input/output channels
        """
        super(MRF, self).__init__()
        self.res_blocks = nn.ModuleList([
            ResBlock(channels, kernel_size=3),
            ResBlock(channels, kernel_size=7),
            ResBlock(channels, kernel_size=11),
        ])

    def forward(self, x):
        """
        =====inputs=====
        x: (B, channels, F)
        =====outputs=====
        x: (B, channels, F)
        """
        skip_list = []
        for res_block in self.res_blocks:
            skip_x = res_block(x)
            skip_list.append(skip_x)
        x = sum(skip_list) / len(self.res_blocks)
        return x

    def remove_weight_norm(self):
        for block in self.res_blocks:
            block.remove_weight_norm()


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.pre_conv = weight_norm(nn.Conv1d(80, 128, kernel_size=7, stride=1, dilation=1,
                                              padding=(7-1)//2))  # (B, 80, F) -> (B, 128, F)
        nn.init.normal_(self.pre_conv.weight, mean=0.0, std=0.01)  # not in the original authors' implementation

        self.up_convs = nn.ModuleList()
        self.mrfs = nn.ModuleList()
        ku = [16, 16, 4, 4]
        for i in range(4):
            # upsample by a factor of ku[i]//2
            channels = 128//(2**(i+1))
            up_conv = weight_norm(nn.ConvTranspose1d(128//(2**i), channels, kernel_size=ku[i], stride=ku[i]//2,
                                                     padding=(ku[i]-ku[i]//2)//2))
            # (B, 128, F) -(1)-> (B, 64, F*8) -(2)-> (B, 32, F*8*8) -(3)-> (B, 16, F*8*8*2) -(4)-> (B, 8, F*8*8*2*2)
            nn.init.normal_(up_conv.weight, mean=0.0, std=0.01)
            self.up_convs.append(up_conv)

            # MRF
            mrf = MRF(channels)  # (B, channels, F) -> (B, channels, F)
            self.mrfs.append(mrf)

        self.post_conv = weight_norm(nn.Conv1d(8, 1, kernel_size=7, stride=1, dilation=1,
                                               padding=(7-1)//2))  # (B, 8, F*256) -> (B, 1, F*256)
        nn.init.normal_(self.post_conv.weight, mean=0.0, std=0.01)

    def forward(self, x):
        """
        =====inputs=====
        x: (B, 80, F)  # mel_spectrogram
        =====outputs=====
        x: (B, 1, F*256)  # waveform
        """
        x = self.pre_conv(x)  # (B, 80, F) -> (B, 128, F)
        for i in range(4):
            x = F.leaky_relu(x, 0.1)
            x = self.up_convs[i](x)
            x = self.mrfs[i](x)
            # final: (B, 128, F) -> (B, 8, F*256)
        x = F.leaky_relu(x, 0.1)
        x = self.post_conv(x)  # (B, 8, F*256) -> (B, 1, F*256)
        x = torch.tanh(x)
        return x

    def remove_weight_norm(self):
        print('Removing weight norm...')
        for l in self.up_convs:
            remove_weight_norm(l)
        for l in self.mrfs:
            l.remove_weight_norm()
        remove_weight_norm(self.pre_conv)
        remove_weight_norm(self.post_conv)


class SubPD(nn.Module):
    def __init__(self, period):
        # period: one of 2, 3, 5, 7, 11
        super(SubPD, self).__init__()
        self.period = period

        self.convs = nn.ModuleList()
        channels = 1
        for i in range(1, 5):  # implemented as in the paper, rather than the authors' modified implementation
            conv = weight_norm(nn.Conv2d(channels, 2**(5+i), kernel_size=(5, 1), stride=(3, 1), dilation=1, padding=0))
            self.convs.append(conv)
            channels = 2**(5+i)
        # (B, 1, [T/p]+1, p) -(1)-> (B, 64, ?, p) -(2)-> (B, 128, ?, p) -(3)-> (B, 256, ?, p) -(4)-> (B, 512, ?, p)
        last_conv = weight_norm(nn.Conv2d(channels, 1024, kernel_size=(5, 1), stride=(1, 1), dilation=1,
                                          padding=(2, 0)))  # (B, 512, ?, p) -> (B, 1024, ?, p)
        self.convs.append(last_conv)

        self.post_conv = weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1), stride=(1, 1), dilation=1,
                                               padding=(1, 0)))  # (B, 1024, ?, p) -> (B, 1, ?, p)

    def forward(self, waveform):
        """
        =====inputs=====
        waveform: (B, 1, T)
        =====outputs=====
        x: (B, ?)  # flattened real/fake score vector (0~1(?))
        features: list of all intermediate features (used for the Feature Matching Loss)
        """
        features = []

        B, _, T = waveform.size()
        P = self.period
        # padding
        if T % P != 0:
            padding = P - (T % P)
            waveform = F.pad(waveform, (0, padding), "reflect")  # pad nothing on the left and `padding` on the right; reflect mode mirrors the signal at the boundary
            # e.g. padding [1, 2, 3, 4, 5] with 2 on the left and 3 on the right in reflect mode -> [3, 2, 1, 2, 3, 4, 5, 4, 3, 2]
            T += padding
        # reshape
        x = waveform.view(B, 1, T//P, P)  # (B, 1, [T/P]+1, P)

        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, 0.1)
            features.append(x)
        x = self.post_conv(x)
        features.append(x)
        x = torch.flatten(x, 1, -1)  # flatten from dim 1 through the last dim | (B, ?)
        ##### Is it okay to skip a sigmoid or clipping step here...?
        return x, features


class MPD(nn.Module):
    def __init__(self):
        super(MPD, self).__init__()
        self.sub_pds = nn.ModuleList([
            SubPD(2), SubPD(3), SubPD(5), SubPD(7), SubPD(11),
        ])  # (B, 1, T) -> (B, ?), features list

    def forward(self, real_waveform, gen_waveform):
        """
        =====inputs=====
        real_waveform: (B, 1, T)  # real speech
        gen_waveform: (B, 1, T)   # generated speech
        =====outputs=====
        real_outputs: (B, ?) list (len=5)  # SubPD outputs for the real speech
        gen_outputs: (B, ?) list           # SubPD outputs for the generated speech
        real_features: features list       # SubPD features for the real speech
        gen_features: features list        # SubPD features for the generated speech
        """
        real_outputs, gen_outputs, real_features, gen_features = [], [], [], []
        for sub_pd in self.sub_pds:
            real_output, real_feature = sub_pd(real_waveform)
            gen_output, gen_feature = sub_pd(gen_waveform)
            real_outputs.append(real_output)
            gen_outputs.append(gen_output)
            real_features.append(real_feature)
            gen_features.append(gen_feature)
        return real_outputs, gen_outputs, real_features, gen_features


class SubSD(nn.Module):
    def __init__(self, first=False):
        """
        first: boolean (if True, apply spectral normalization)
        """
        super(SubSD, self).__init__()
        norm = spectral_norm if first else weight_norm  # spectral_norm if first is True, otherwise weight_norm
        self.convs = nn.ModuleList([  # implemented following the MelGAN paper
            norm(nn.Conv1d(1, 16, kernel_size=15, stride=1, padding=(15-1)//2)),  # (B, 1, T) -> (B, 16, T)
            norm(nn.Conv1d(16, 64, kernel_size=41, stride=4, groups=4, padding=(41-1)//2)),  # (B, 16, T) -> (B, 64, T/4(?))
            norm(nn.Conv1d(64, 256, kernel_size=41, stride=4, groups=16, padding=(41-1)//2)),  # (B, 64, T/4(?)) -> (B, 256, T/16(?))
            norm(nn.Conv1d(256, 1024, kernel_size=41, stride=4, groups=64, padding=(41-1)//2)),  # (B, 256, T/16(?)) -> (B, 1024, T/64(?))
            norm(nn.Conv1d(1024, 1024, kernel_size=41, stride=4, groups=256, padding=(41-1)//2)),  # (B, 1024, T/64(?)) -> (B, 1024, T/256(?))
            norm(nn.Conv1d(1024, 1024, kernel_size=5, stride=1, padding=(5-1)//2))  # (B, 1024, T/256(?)) -> (B, 1024, T/256(?))
        ])
        self.post_conv = norm(nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=(3-1)//2))  # (B, 1024, ?) -> (B, 1, ?)

    def forward(self, waveform):
        """
        =====inputs=====
        waveform: (B, 1, T)
        =====outputs=====
        x: (B, ?)  # flattened real/fake score vector (0~1(?))
        features: list of all intermediate features (used for the Feature Matching Loss)
        """
        features = []
        x = waveform
        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, 0.1)
            features.append(x)
        x = self.post_conv(x)  # (B, 1, ?)
        features.append(x)
        x = x.squeeze(1)  # (B, ?)
        ##### Is it okay to skip a sigmoid or clipping step here...?
        return x, features


class MSD(nn.Module):
    def __init__(self):
        super(MSD, self).__init__()
        self.sub_sds = nn.ModuleList([
            SubSD(first=True), SubSD(), SubSD()
        ])  # (B, 1, T) -> (B, ?), features list
        self.avgpool = nn.AvgPool1d(kernel_size=4, stride=2, padding=2)  # x2 downsampling

    def forward(self, real_waveform, gen_waveform):
        """
        =====inputs=====
        real_waveform: (B, 1, T)  # real speech
        gen_waveform: (B, 1, T)   # generated speech
        =====outputs=====
        real_outputs: (B, ?) list (len=3)  # SubSD outputs for the real speech
        gen_outputs: (B, ?) list           # SubSD outputs for the generated speech
        real_features: features list       # SubSD features for the real speech
        gen_features: features list        # SubSD features for the generated speech
        """
        real_outputs, gen_outputs, real_features, gen_features = [], [], [], []
        for idx, sub_sd in enumerate(self.sub_sds):
            if idx != 0:
                real_waveform = self.avgpool(real_waveform)
                gen_waveform = self.avgpool(gen_waveform)
            real_output, real_feature = sub_sd(real_waveform)
            gen_output, gen_feature = sub_sd(gen_waveform)
            real_outputs.append(real_output)
            gen_outputs.append(gen_output)
            real_features.append(real_feature)
            gen_features.append(gen_feature)
        return real_outputs, gen_outputs, real_features, gen_features
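
The following is an illustrative shape check, not part of the uploaded files: assuming Hmodel.py is importable as a module named Hmodel, the generator and the two discriminators above can be exercised on dummy tensors roughly as follows (batch size and mel length are arbitrary assumptions).

import torch
from Hmodel import Generator, MPD, MSD

gen, mpd, msd = Generator(), MPD(), MSD()
mel = torch.randn(2, 80, 32)                    # (B, 80, F) mel-spectrogram
fake_wav = gen(mel)                             # (B, 1, F*256), since the strides multiply to 8*8*2*2 = 256
real_wav = torch.rand(2, 1, 32 * 256) * 2 - 1   # dummy "real" waveform of the same length
# Each discriminator returns (real_outputs, gen_outputs, real_features, gen_features).
p_real, p_gen, p_real_feat, p_gen_feat = mpd(real_wav, fake_wav)
s_real, s_gen, s_real_feat, s_gen_feat = msd(real_wav, fake_wav)
print(fake_wav.shape, len(p_gen), len(s_gen))   # torch.Size([2, 1, 8192]) 5 3
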
app.py ADDED
@@ -0,0 +1,135 @@
import streamlit as st
import soundfile as sf
import timeit
import uuid

import os
import re

import torch
from scipy.io.wavfile import write
import IPython.display as ipd

from datautils import *
from model import Generator as Glow_model
from utils import scan_checkpoint, plot_mel, plot_alignment
from Hmodel import Generator as GAN_model

MAX_WAV_VALUE = 32768.0
device = torch.device('cuda:0')
torch.cuda.manual_seed(1234)
name = '1038_eunsik_01'

# Nix
from nix.models.TTS import NixTTSInference

def init_session_state():
    # Model
    if "init_model" not in st.session_state:
        st.session_state.init_model = True
        st.session_state.model_variant = "KSS"
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")

def update_model():
    if st.session_state.model_variant == "KSS":
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-v0.1")
    elif st.session_state.model_variant == "은식":
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")

def update_session_state(state_id, state_value):
    st.session_state[f"{state_id}"] = state_value

def centered_text(input_text, mode = "h1",):
    st.markdown(
        f"<{mode} style='text-align: center;'>{input_text}</{mode}>", unsafe_allow_html = True)

def generate_voice(input_text,):
    # TTS Inference
    c, c_length, phoneme = st.session_state.TTS.tokenize(input_text)
    voice = st.session_state.TTS.vocalize(c, c_length)

    # Save audio (bug in Streamlit, can't play a numpy array directly)
    sf.write(f"cache_sound/{input_text}.wav", voice[0, 0], 22050)

    # Play audio
    st.audio(f"cache_sound/{input_text}.wav", format = "audio/wav")
    os.remove(f"cache_sound/{input_text}.wav")
    st.caption("Generated Voice")

st.set_page_config(
    page_title = "소신 Team Demo",
    page_icon = "🔉",
)

init_session_state()

centered_text("🔉 소신 Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")

mode = "p"
st.markdown(
    f"<{mode} style='text-align: left;'><small>This is a demo trained on our own voice.&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; The voice \"KSS\" was trained 3 times; \"은식\" was fine-tuned from \"KSS\" 3 times.&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; We adapted this demo format from the Nix-TTS Interactive Demo.</small></{mode}>",
    unsafe_allow_html = True
)

st.write(" ")
st.write(" ")
col1, col2 = st.columns(2)

with col1:
    input_text = st.text_input(
        "한글로만 입력해주세요",
        value = "딥러닝은 정말 재밌어!",
    )
with col2:
    model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "은식"], index = 1)
    if model_variant != st.session_state.model_variant:
        # Update variant choice
        update_session_state("model_variant", model_variant)
        # Re-load model
        update_model()

button_gen = st.button("Generate Voice")
if button_gen:
    generate_voice(input_text)


# Standalone Glow-TTS + HiFi-GAN inference helper (not wired into the Streamlit demo above).
class TTS:
    def __init__(self, model_variant):
        self.flowgenerator = Glow_model(n_vocab = 70, h_c = 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads = 2, layers_enc = 6)
        self.voicegenerator = GAN_model()
        if model_variant == '은식':
            last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
            check_point = torch.load(last_chpt1)
            self.flowgenerator.load_state_dict(check_point['generator'])
            self.flowgenerator.decoder.skip()
            self.flowgenerator.eval()
        if model_variant == '은식':
            last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
            check_point = torch.load(last_chpt2)
            self.voicegenerator.load_state_dict(check_point['gen_model'])
            self.voicegenerator.eval()
            self.voicegenerator.remove_weight_norm()

    def inference(self, input_text):
        # Strip punctuation, then convert the sentence to a jamo id sequence.
        filters = '([.,!?])'
        sentence = re.sub(re.compile(filters), '', input_text)
        x = text_to_sequence(sentence)
        x = torch.autograd.Variable(torch.tensor(x).unsqueeze(0)).to(device).long()
        x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)

        with torch.no_grad():
            noise_scale = .667
            length_scale = 1.0
            (y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
            y = self.voicegenerator(y_gen_tst)
            audio = y.squeeze() * MAX_WAV_VALUE
            audio = audio.cpu().numpy().astype('int16')

        # NOTE: out_dir (the output directory for generated wav files) is expected to be defined elsewhere.
        output_file = os.path.join(out_dir, 'gen_' + sentence[:3] + '.wav')
        write(output_file, 22050, audio)
        print(f'{sentence} is stored in {out_dir}')

        fig1 = plot_mel(y_gen_tst[0].data.cpu().numpy())
        fig2 = plot_alignment(attn_gen[0, 0].data.cpu().numpy(), sequence_to_text(x[0].data.cpu().numpy()))
        ipd.display(fig1, fig2)
        ipd.Audio(filename=output_file)
        return audio
commons.py ADDED
@@ -0,0 +1,214 @@
import torch
import torch.nn.functional as F
from module import *
import numpy as np
import math


def sequence_mask(length, max_length=None):
    if max_length is None:
        max_length = max(length)
    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def maximum_path(value, mask, max_neg_val=-np.inf):
    """ Numpy-friendly version. It's about 4 times faster than the torch version.
    value: [b, t_x, t_y]
    mask: [b, t_x, t_y]
    """
    value = value * mask

    device = value.device
    dtype = value.dtype
    value = value.cpu().detach().numpy()
    mask = mask.cpu().detach().numpy().astype(bool)

    b, t_x, t_y = value.shape
    direction = np.zeros(value.shape, dtype=np.int64)
    v = np.zeros((b, t_x), dtype=np.float32)
    x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)

    for j in range(t_y):
        v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[:, :-1]
        v1 = v
        max_mask = (v1 >= v0)
        v_max = np.where(max_mask, v1, v0)
        direction[:, :, j] = max_mask

        index_mask = (x_range <= j)
        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
    direction = np.where(mask, direction, 1)

    path = np.zeros(value.shape, dtype=np.float32)
    index = mask[:, :, 0].sum(1).astype(np.int64) - 1
    index_range = np.arange(b)

    for j in reversed(range(t_y)):
        path[index_range, index, j] = 1
        index = index + direction[index_range, index, j] - 1

    path = path * mask.astype(np.float32)
    path = torch.from_numpy(path).to(device=device, dtype=dtype)
    return path


def generate_path(duration, mask):
    """
    duration: [b, t_x]
    mask: [b, t_x, t_y]
    """
    device = duration.device

    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)

    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, (0, 0, 1, 0, 0, 0))[:, :-1]
    path = path * mask
    return path


def mle_loss(z, m, logs, logdet, mask):
    # negative normal log-likelihood without the constant term
    l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m)**2))
    l = l - torch.sum(logdet)  # log Jacobian determinant
    # average across batch, channel and time axes
    l = l / torch.sum(torch.ones_like(z) * mask)
    l = l + 0.5 * math.log(2 * math.pi)  # add the remaining constant term
    return l


def duration_loss(logw, logw_, lengths):
    l = torch.sum((logw - logw_)**2) / torch.sum(lengths)
    return l


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def GAN_Loss_Generator(gen_outputs):
    """
    gen_outputs: (B, ?) list  # outputs of the MPD (len=5) or MSD (len=3)
    """
    loss = 0
    for DG in gen_outputs:
        loss += torch.mean((DG-1)**2)
    return loss


def GAN_Loss_Discriminator(real_outputs, gen_outputs):
    """
    real_outputs: (B, ?) list  # outputs of the MPD (len=5) or MSD (len=3)
    gen_outputs: (B, ?) list   # outputs of the MPD (len=5) or MSD (len=3)
    """
    loss = 0
    for D, DG in zip(real_outputs, gen_outputs):
        loss += torch.mean((D-1)**2 + DG**2)
    return loss


def Mel_Spectrogram_Loss(real_mel, gen_mel):
    """
    real_mel: (B, F, 80)  # mel-spectrogram taken from the DataLoader
    gen_mel: (B, F, 80)   # mel-spectrogram of the waveform produced by the Generator
    """
    loss = F.l1_loss(real_mel, gen_mel)
    return 45*loss


def Feature_Matching_Loss(real_features, gen_features):
    """
    real_features: (?, ..., ?) list of lists  # outputs of the MPD (len=[5, 6]) or MSD (len=[3, 7])
    gen_features: (?, ..., ?) list of lists   # outputs of the MPD (len=[5, 6]) or MSD (len=[3, 7])
    """
    loss = 0
    for Ds, DGs in zip(real_features, gen_features):
        for D, DG in zip(Ds, DGs):
            loss += torch.mean(torch.abs(D - DG))
    return 2*loss


def Final_Loss_Generator(mpd_gen_outputs, mpd_real_features, mpd_gen_features,
                         msd_gen_outputs, msd_real_features, msd_gen_features,
                         real_mel, gen_mel):
    """
    =====inputs=====
    [:3]: the last 3 of the MPD outputs
    [3:6]: the last 3 of the MSD outputs
    [6:8]: real_mel and gen_mel
    =====outputs=====
    Generator_Loss
    Mel_Loss
    """
    Gen_Adv1 = GAN_Loss_Generator(mpd_gen_outputs)
    Gen_Adv2 = GAN_Loss_Generator(msd_gen_outputs)
    Adv = Gen_Adv1 + Gen_Adv2
    FM1 = Feature_Matching_Loss(mpd_real_features, mpd_gen_features)
    FM2 = Feature_Matching_Loss(msd_real_features, msd_gen_features)
    FM = FM1 + FM2
    Mel_Loss = Mel_Spectrogram_Loss(real_mel, gen_mel)
    Generator_Loss = Adv + FM + Mel_Loss

    return Generator_Loss, Mel_Loss, Adv, FM


def Final_Loss_Discriminator(mpd_real_outputs, mpd_gen_outputs,
                             msd_real_outputs, msd_gen_outputs):
    """
    =====inputs=====
    [:2]: the first 2 of the MPD outputs
    [2:4]: the first 2 of the MSD outputs
    =====outputs=====
    Discriminator_Loss
    """
    Disc_Adv1 = GAN_Loss_Discriminator(mpd_real_outputs, mpd_gen_outputs)
    Disc_Adv2 = GAN_Loss_Discriminator(msd_real_outputs, msd_gen_outputs)
    Discriminator_Loss = Disc_Adv1 + Disc_Adv2

    return Discriminator_Loss


class Adam():
    def __init__(self, params, scheduler, dim_model, warmup_steps=4000, lr=1e0, betas=(0.9, 0.98), eps=1e-9):
        self.params = params
        self.scheduler = scheduler
        self.dim_model = dim_model
        self.warmup_steps = warmup_steps
        self.lr = lr
        self.betas = betas
        self.eps = eps
        self.step_num = 1
        self.cur_lr = lr * self._get_lr_scale()

        self._optim = torch.optim.Adam(params, lr=self.cur_lr, betas=betas, eps=eps)
        self.param_groups = self._optim.param_groups

    def _get_lr_scale(self):
        if self.scheduler == "noam":
            return np.power(self.dim_model, -0.5) * np.min([np.power(self.step_num, -0.5), self.step_num * np.power(self.warmup_steps, -1.5)])
        else:
            return 1

    def _update_learning_rate(self):
        self.step_num += 1
        if self.scheduler == "noam":
            self.cur_lr = self.lr * self._get_lr_scale()
            for param_group in self._optim.param_groups:
                param_group['lr'] = self.cur_lr

    def get_lr(self):
        return self.cur_lr

    def step(self):
        self._optim.step()
        self._update_learning_rate()

    def zero_grad(self):
        self._optim.zero_grad()

    def load_state_dict(self, d):
        self._optim.load_state_dict(d)

    def state_dict(self):
        return self._optim.state_dict()
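
As a rough orientation (again, not part of the upload), the HiFi-GAN loss helpers above would typically be wired to the MPD/MSD outputs as sketched below. This continues from the shape-check example after Hmodel.py (gen, mpd, msd, real_wav, fake_wav); the mel tensors are dummy stand-ins, and optimizer steps and backward calls are omitted.

from commons import Final_Loss_Generator, Final_Loss_Discriminator

real_mel = torch.randn(2, 80, 32)   # ground-truth mel (dummy stand-in)
gen_mel = torch.randn(2, 80, 32)    # mel of the generated waveform (normally via datautils.mel_spectrogram)

# Discriminator update: the generated waveform is detached so only the discriminators receive gradients.
mpd_real, mpd_gen, _, _ = mpd(real_wav, fake_wav.detach())
msd_real, msd_gen, _, _ = msd(real_wav, fake_wav.detach())
d_loss = Final_Loss_Discriminator(mpd_real, mpd_gen, msd_real, msd_gen)

# Generator update: adversarial + feature-matching + mel losses, combined inside Final_Loss_Generator.
_, mpd_gen, mpd_real_feat, mpd_gen_feat = mpd(real_wav, fake_wav)
_, msd_gen, msd_real_feat, msd_gen_feat = msd(real_wav, fake_wav)
g_loss, mel_loss, adv, fm = Final_Loss_Generator(
    mpd_gen, mpd_real_feat, mpd_gen_feat,
    msd_gen, msd_real_feat, msd_gen_feat,
    real_mel, gen_mel)
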
datautils.py ADDED
@@ -0,0 +1,67 @@
from jamo import hangul_to_jamo
import librosa
import torch

sample_rate = 22050
preemphasis = 0.97
n_fft = 1024
hop_length = 256
win_length = 1024
ref_db = 20
max_db = 100
mel_dim = 80

PAD = '_'
EOS = '~'
SPACE = ' '

JAMO_LEADS = "".join([chr(_) for _ in range(0x1100, 0x1113)])
JAMO_VOWELS = "".join([chr(_) for _ in range(0x1161, 0x1176)])
JAMO_TAILS = "".join([chr(_) for _ in range(0x11A8, 0x11C3)])

VALID_CHARS = JAMO_LEADS + JAMO_VOWELS + JAMO_TAILS + SPACE
symbols = PAD + EOS + VALID_CHARS

_symbol_to_id = {s: i for i, s in enumerate(symbols)}
_id_to_symbol = {i: s for i, s in enumerate(symbols)}

# Split text into initial/medial/final jamo and convert each symbol to its id.
def text_to_sequence(text):
    sequence = []
    if not 0x1100 <= ord(text[0]) <= 0x1113:
        text = ''.join(list(hangul_to_jamo(text)))
    for s in text:
        sequence.append(_symbol_to_id[s])
    sequence.append(_symbol_to_id['~'])
    return sequence

def sequence_to_text(sequence):
    result = ''
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            result += s
    return result.replace('}{', ' ')

def mel_spectrogram(y, n_fft=1024, num_mels=80, sampling_rate=22050, hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False):
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))
    """

    mel = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=torch.hann_window(win_size).to(y.device),
                      center=center, pad_mode='reflect', normalized=False, onesided=True)

    spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))

    spec = torch.matmul(torch.from_numpy(mel).float().to(y.device), spec)
    spec = torch.log(torch.clamp(spec, min=1e-5) * 1)

    return spec
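
For orientation only (not part of the upload): under the pinned requirements (torch 1.4.0, librosa 0.7.2), the text and mel helpers above could be exercised as follows; the sentence and the waveform are illustrative dummies.

import torch
from datautils import text_to_sequence, mel_spectrogram

ids = text_to_sequence("안녕하세요")        # jamo symbol ids with EOS ('~') appended
x = torch.LongTensor(ids).unsqueeze(0)      # (1, T_text), e.g. input for the Glow-TTS encoder
wav = torch.rand(1, 22050) * 2 - 1          # dummy 1-second waveform in [-1, 1]
mel = mel_spectrogram(wav)                  # (1, 80, frames) log-mel spectrogram
print(len(ids), mel.shape)
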
log/1038_eunsik_01/Glow_TTS_00289602.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7c84ba863d822d97db5c00dc4e69fcdf15d9040c456cab6503179bb220748cee
size 279930587
log/1038_eunsik_01/HiFI_GAN_00257000.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2db9dd4f1fa06c40d98a19ac4e1740e390d66de5cef3138d670a29d8f917bddb
size 421187547
model.py ADDED
@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from module import *
from commons import *
import math


class Generator(nn.Module):
    def __init__(self, n_vocab, h_c, f_c, f_c_dp, out_c, k_s=3, k_s_dec=5, heads=2, layers_enc=6):
        super().__init__()
        self.encoder = Encoder(n_vocab, out_c, h_c, f_c, f_c_dp, heads=heads, layers=layers_enc, k_s=k_s)
        self.decoder = Decoder(in_c=out_c, hi_c=h_c, k_s=k_s_dec)

    def forward(self, x, x_lengths, y=None, y_lengths=None, gen=False, noise_scale=1., length_scale=1.):
        x_m, x_logs, logw, x_mask = self.encoder(x, x_lengths)
        if gen:
            w = torch.exp(logw) * x_mask * length_scale
            w_ceil = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
            y_max_length = None
            y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)

            z_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
            attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)
            attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)

            z_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            z_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask

            z = (z_m + torch.exp(z_logs) * torch.randn_like(z_m) * noise_scale) * z_mask
            y, logdet = self.decoder(z, z_mask, reverse=True)
            return (y, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)

        else:
            y_max_length = y.size(2)
            y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
            z_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
            attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)

            z, logdet = self.decoder(y, z_mask, reverse=False)
            with torch.no_grad():
                x_s_sq_r = torch.exp(-2 * x_logs)
                logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(-1)  # [b, t, 1]
                logp2 = torch.matmul(x_s_sq_r.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp3 = torch.matmul((x_m * x_s_sq_r).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp4 = torch.sum(-0.5 * (x_m ** 2) * x_s_sq_r, [1]).unsqueeze(-1)  # [b, t, 1]
                logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']

                attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
            z_m = torch.matmul(attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            z_logs = torch.matmul(attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask

            return (z, z_m, z_logs, logdet, z_mask), (x_m, x_logs, x_mask), (attn, logw, logw_)

    def preprocess(self, y, y_lengths, y_max_length):
        if y_max_length is not None:
            y_max_length = (y_max_length // 2) * 2
            y = y[:, :, :y_max_length]
        y_lengths = (y_lengths // 2) * 2
        return y, y_lengths, y_max_length


class Encoder(nn.Module):
    def __init__(self, n_vocab, out_c, h_c, f_c, f_c_dp, heads, layers, k_s, p=0.1, mean_only=True):
        super().__init__()
        self.h_c = h_c
        self.mean_only = mean_only
        self.emb = nn.Embedding(n_vocab, h_c)
        nn.init.normal_(self.emb.weight, 0.0, h_c**(-0.5))
        self.prenet = Prenet(in_c=h_c, hi_c=h_c, out_c=h_c, k_s=5)
        self.drop = nn.Dropout(p=p)
        self.atten_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        for i in range(layers):
            self.atten_layers.append(MultiheadAttention(h_c, h_c, heads, window_size=4, heads_share=True, p=0.1, block_length=None))
            self.norm_layers.extend([Layernorm(h_c), Layernorm(h_c)])
            self.ffn_layers.append(FFN(h_c, f_c, k_s, p))
        self.proj_m = nn.Conv1d(h_c, out_c, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(h_c, out_c, 1)
        self.proj_w = DurationPredictor(h_c, f_c_dp, k_s, p=p)

    def forward(self, x, x_length):
        x = self.emb(x) * math.sqrt(self.h_c)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(sequence_mask(x_length, x.size(2)), 1).to(x.dtype)

        x = self.prenet(x, x_mask)
        atten_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(len(self.atten_layers)):
            x = x * x_mask
            y = self.drop(self.atten_layers[i](x, atten_mask))
            x = self.norm_layers[2*i](x + y)

            y = self.drop(self.ffn_layers[i](x, x_mask))
            x = self.norm_layers[2*i+1](x + y)
        x = x * x_mask

        x_m = self.proj_m(x)
        if not self.mean_only:
            x_logs = self.proj_s(x)
        else:
            x_logs = torch.zeros_like(x_m)
        logw = self.proj_w(x.detach(), x_mask)

        return x_m, x_logs, logw, x_mask


class Decoder(nn.Module):
    def __init__(self, in_c, hi_c, k_s, d_l=1, blocks=12, splits=4):
        super().__init__()
        self.flows = nn.ModuleList()
        for _ in range(blocks):
            self.flows.extend([ActNorm(in_c*2), InvConvNear(splits=splits), Couplinglayer(in_c*2, hi_c, k_s, d_l=d_l)])

    def forward(self, x, x_mask=None, reverse=False):
        if not reverse:
            flows = self.flows
            tot_logdet = 0
        else:
            flows = reversed(self.flows)
            tot_logdet = None

        b, c, t = x.shape
        t = t - t % 2
        if x_mask is None:
            mask = torch.ones(b, 1, t//2)
        else:
            mask = x_mask[:, :, 1::2]
        x = x[:, :, :t].reshape(b, c, t//2, 2).transpose(2, 3).contiguous().reshape(b, 2*c, t//2) * mask  # [b, 2c, t/2]
        for f in flows:
            x, logdet = f(x, mask, reverse=reverse)
            if not reverse:
                tot_logdet = tot_logdet + logdet
        if x_mask is None:
            mask = torch.ones(b, 1, t)
        else:
            mask = x_mask[:, :, :t]
        x = x.reshape(b, c, 2, t//2).transpose(2, 3).contiguous().reshape(b, c, t) * mask  # [b, c, t]
        return x, tot_logdet

    def skip(self):
        for f in self.flows:
            f.skip()

    def ddi_init(self):
        for i, f in enumerate(self.flows):
            if i % 3 == 0:
                f.set_ddi()
module.py ADDED
@@ -0,0 +1,299 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

######################################### encoder ##############################################

class Layernorm(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, channels))
        self.beta = nn.Parameter(torch.zeros(1, channels))

    def forward(self, x):
        m = torch.mean(x, dim=1, keepdim=True)
        v = torch.mean((x - m)**2, dim=1, keepdim=True)
        x = (x - m) * torch.rsqrt(v + 1e-4)  # normalization
        n = len(x.shape)
        shape = [1, -1] + [1]*(n-2)
        x = x*self.gamma.reshape(*shape) + self.beta.reshape(*shape)
        return x


class Prenet(nn.Module):
    def __init__(self, in_c, hi_c, out_c, k_s=5, layers=3, p=0.05):
        super().__init__()
        self.crn = nn.ModuleList()
        self.crn.extend([nn.Conv1d(in_c, hi_c, k_s, padding=k_s//2), Layernorm(hi_c), nn.ReLU(), nn.Dropout(p=p)])
        self.crn.extend([nn.Conv1d(hi_c, hi_c, k_s, padding=k_s//2), Layernorm(hi_c), nn.ReLU(), nn.Dropout(p=p)])
        self.crn.extend([nn.Conv1d(hi_c, hi_c, k_s, padding=k_s//2), Layernorm(hi_c), nn.ReLU(), nn.Dropout(p=p)])

        self.proj = nn.Conv1d(hi_c, out_c, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, start, x_mask=1):
        x = start
        for layer in self.crn:
            x = layer(x)  # [b, c, t]
            x = x * x_mask
        x = self.proj(x) + start  # [b, c, t]
        end = x * x_mask
        return end  # [b, c, t]


class MultiheadAttention(nn.Module):
    def __init__(self, c, out_c, heads, window_size=4, heads_share=True, p=0.1, block_length=None):
        super().__init__()

        self.k = c // heads
        self.window_size = window_size
        self.proj_q = nn.Conv1d(c, c, 1)
        self.proj_k = nn.Conv1d(c, c, 1)
        self.proj_v = nn.Conv1d(c, c, 1)

        nn.init.xavier_uniform_(self.proj_q.weight)
        nn.init.xavier_uniform_(self.proj_k.weight)
        nn.init.xavier_uniform_(self.proj_v.weight)

        n_heads_rel = 1 if heads_share else heads
        self.d_k = (self.k)**(-0.5)
        self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size*2 + 1, self.k) * self.d_k)
        self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size*2 + 1, self.k) * self.d_k)

        self.conv_o = nn.Conv1d(c, out_c, 1)
        self.drop = nn.Dropout(p=p)

    def forward(self, x, attn_mask=None):
        query, key, value = self.proj_q(x), self.proj_k(x), self.proj_v(x)
        b, c, t = query.shape
        h, k = c // self.k, self.k

        query = query.reshape(b, h, k, t)
        key = key.reshape(b, h, k, t)
        value = value.reshape(b, h, k, t)

        matrix = self.get_relative_matrix(self.emb_rel_k, t)
        rel_logit = torch.matmul(matrix.unsqueeze(0), query)  # [1,1,2t-1,k] * [b,h,k,t] = [b,h,2t-1,t]
        abs_logit = self.rel_to_abs(rel_logit.transpose(2, 3))
        local_score = abs_logit * self.d_k

        score = torch.matmul(query.transpose(2, 3), key) * self.d_k + local_score
        if attn_mask is not None:
            score = score.masked_fill(attn_mask == 0, -1e4)

        align = F.softmax(score, dim=-1)
        atten = self.drop(align)
        self.atten = atten

        matrix = self.get_relative_matrix(self.emb_rel_v, t).transpose(1, 2)  # [1,k,2t-1]
        weight = self.abs_to_rel(atten).transpose(2, 3)  # [b,h,2t-1,t]
        output = torch.matmul(value, atten) + torch.matmul(matrix.unsqueeze(0), weight)  # [b,h,k,t]
        x = self.conv_o(output.contiguous().reshape(b, c, t))
        return x

    def get_relative_matrix(self, emb_rel_k, t):
        s = self.window_size
        pad_size = max(t - s - 1, 0)
        start = max(s + 1 - t, 0)
        emb_rel_k = F.pad(emb_rel_k, (0, 0, pad_size, pad_size))
        return emb_rel_k[:, start:start + 2*t - 1]

    def rel_to_abs(self, x):
        b, h, t, _ = x.shape
        x = F.pad(x, (0, 1)).reshape(b, h, 2*t*t)
        x = F.pad(x, (0, t-1)).reshape(b, h, t+1, 2*t-1)[:, :, :t, t-1:]
        return x

    def abs_to_rel(self, x):
        b, h, t, t = x.shape
        x = F.pad(x, (0, t-1)).reshape(b, h, 2*t*t - t)
        x = F.pad(x, (t, 0)).reshape(b, h, t, 2*t)[:, :, :, 1:]
        return x


class FFN(nn.Module):
    def __init__(self, h_c, f_c, k_s, p=0.1):
        super().__init__()
        self.conv1 = nn.Conv1d(h_c, f_c, k_s, padding=k_s//2)
        self.conv2 = nn.Conv1d(f_c, h_c, k_s, padding=k_s//2)
        self.drop = nn.Dropout(p=p)

    def forward(self, x, x_mask=None):
        x = self.conv2(self.drop(F.relu(self.conv1(x*x_mask)))*x_mask)
        return x * x_mask


class DurationPredictor(nn.Module):
    def __init__(self, in_c, f_c, k_s, p=0.1):
        super().__init__()
        self.block1 = nn.Sequential(nn.Conv1d(in_c, f_c, k_s, padding=k_s//2),
                                    nn.ReLU(),
                                    Layernorm(f_c),
                                    nn.Dropout(p=p))
        self.block2 = nn.Sequential(nn.Conv1d(f_c, f_c, k_s, padding=k_s//2),
                                    nn.ReLU(),
                                    Layernorm(f_c),
                                    nn.Dropout(p=p))
        self.proj = nn.Conv1d(f_c, 1, 1)

    def forward(self, x, x_mask):
        x = self.block1(x * x_mask)
        x = self.block2(x * x_mask)
        x = self.proj(x * x_mask)
        return x * x_mask


######################################### decoder ##############################################
# TorchScript fusion: static typing of tensors, an optimized computation graph, and ahead-of-time
# compilation >> to accelerate this fused op.

@torch.jit.script
def fuse_tan_sig_add(x: torch.Tensor, mid: int) -> torch.Tensor:
    a, b = x[:, :mid, :], x[:, mid:, :]
    return torch.sigmoid(a) * torch.tanh(b)


class WN(nn.Module):  # non-causal WaveNet without dilation
    def __init__(self, hi_c, k_s, d_l=1, layers=3, p=0.05):
        super().__init__()
        self.hi_c = hi_c
        self.resblocks = nn.ModuleList()
        self.skipblocks = nn.ModuleList()
        self.drop = nn.Dropout(p=p)
        for _ in range(layers):
            res_layer = nn.Conv1d(hi_c, 2*hi_c, k_s, dilation=d_l, padding=k_s//2)
            res_layer = nn.utils.weight_norm(res_layer, name='weight')
            self.resblocks.append(res_layer)
            if _ == 2:
                skip_layer = nn.Conv1d(hi_c, hi_c, 1)  # last layer
            else:
                skip_layer = nn.Conv1d(hi_c, 2*hi_c, 1)
            skip_layer = nn.utils.weight_norm(skip_layer, name='weight')
            self.skipblocks.append(skip_layer)

    def forward(self, x, x_mask=None):
        mid = self.hi_c
        end = torch.zeros_like(x, dtype=x.dtype)
        for i in range(len(self.resblocks)):
            x = self.drop(self.resblocks[i](x))  # [b, 2c, t]
            x = fuse_tan_sig_add(x, mid)  # [b, c, t]
            y = self.skipblocks[i](x)
            if i == 2:
                end = end + y  # last layer
            else:
                x = (x + y[:, :mid, :]) * x_mask
                end = end + y[:, mid:, :]
        return end * x_mask

    def skip(self):
        for layer1, layer2 in zip(self.resblocks, self.skipblocks):
            nn.utils.remove_weight_norm(layer1)
            nn.utils.remove_weight_norm(layer2)


class Couplinglayer(nn.Module):
    def __init__(self, in_c, hi_c, k_s, d_l=1):
        super().__init__()
        s_proj = nn.Conv1d(in_c//2, hi_c, 1)
        self.start = nn.utils.weight_norm(s_proj, name='weight')
        # Initializing the last layer to 0 makes the affine coupling layers
        # do nothing at first. This helps stabilize training (from the Glow paper).
        self.end = nn.Conv1d(hi_c, in_c, 1)
        self.end.weight.data.zero_()
        self.end.bias.data.zero_()
        self.wn = WN(hi_c, k_s, d_l)

    # y = exp(logs) * x + t
    def forward(self, x, x_mask=None, reverse=False):
        if x_mask is None:
            x_mask = 1
        mid = x.shape[1]//2  # split channels in half
        x_0, x_1 = x[:, :mid, :], x[:, mid:, :]
        z_1 = self.end(self.wn(self.start(x_1) * x_mask, x_mask))
        logs, t = z_1[:, mid:, :], z_1[:, :mid, :]
        if reverse:
            x_0 = torch.exp(-logs)*(x_0 - t) * x_mask
            logdet = None
        else:
            x_0 = torch.exp(logs + 1e-4) * x_0 + t
            logdet = torch.sum(logs * x_mask, [1, 2])  # sum(log(s))
        z = torch.cat([x_0, x_1], dim=1)
        return z, logdet

    def skip(self):
        self.wn.skip()


class InvConvNear(nn.Module):
    def __init__(self, splits=4):
        super().__init__()
        self.splits = splits
        w_init = torch.linalg.qr(torch.randn((splits, splits)).normal_())[0]  # orthonormal matrix
        if torch.det(w_init) < 0:
            w_init[0, :] = -w_init[0, :]
        self.weight = nn.Parameter(w_init)

    def forward(self, x, x_mask=None, reverse=False):
        b, c, t = x.shape

        if x_mask is None:
            x_mask = 1
            x_len = torch.ones(b) * t  # [b]
        else:
            x_len = torch.sum(x_mask, [1, 2])

        s = self.splits
        x = x.reshape(b, 2, c//s, s//2, t)  # split channels into 2 groups
        x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(b, s, c//s, t)

        if reverse:
            if hasattr(self, "weight_inv"):
                weight = self.weight_inv
            else:
                weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
            logdet = None
        else:
            weight = self.weight
            logdet = torch.logdet(weight) * (c//s) * x_len  # h*w*log(det(W)); no LU decomposition needed

        weight = weight.unsqueeze(-1).unsqueeze(-1)
        z = F.conv2d(x, weight)  # z = matmul(weight, x_i,j) for i,j in h = c//s, w = t

        z = z.reshape(b, 2, s//2, c//s, t).permute(0, 1, 3, 2, 4).contiguous().reshape(b, c, t) * x_mask
        return z, logdet

    def skip(self):
        self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)


class ActNorm(nn.Module):
    def __init__(self, hi_c, ddi=False):  # ddi: data-dependent initialization
        super().__init__()
        self.logs = nn.Parameter(torch.zeros(1, hi_c, 1))
        self.bias = nn.Parameter(torch.zeros(1, hi_c, 1))
        self.ddi = ddi

    def forward(self, x, x_mask=None, reverse=False):
        b, _, t = x.shape
        if x_mask is None:
            x_mask = torch.ones(b, 1, t).to(device=x.device, dtype=x.dtype)
        x_len = torch.sum(x_mask, [1, 2])
        if self.ddi:
            self.initialize(x, x_mask)
            self.ddi = False
        # y = exp(logs) * x + bias  >> normalization along the channel dim
        if reverse:
            z = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = None
        else:
            z = (torch.exp(self.logs) * x + self.bias) * x_mask
            logdet = torch.sum(self.logs, [1, 2]) * x_len
        return z, logdet

    def initialize(self, x, x_mask):
        with torch.no_grad():
            n = torch.sum(x_mask, [0, 2])
            m = torch.sum(x * x_mask, [0, 2]) / n
            m_s = torch.sum(x * x * x_mask, [0, 2]) / n
            v = m_s - m**2
            logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

            init_bias = (-m * torch.exp(-logs)).reshape(*self.bias.shape).to(dtype=self.bias.dtype)  # -m/s
            init_logs = (-logs).reshape(*self.logs.shape).to(dtype=self.logs.dtype)  # -log(s)

            self.bias.data.copy_(init_bias)
            self.logs.data.copy_(init_logs)

    def set_ddi(self):
        self.ddi = True

    def skip(self):
        pass
requirements.txt ADDED
@@ -0,0 +1,7 @@
torch==1.4.0
numpy==1.17.4
librosa==0.7.2
scipy==1.4.1
tensorboard==2.0
soundfile==0.10.3.post1
matplotlib==3.1.3