uzdzn committed
Commit 2c7b92a
1 Parent(s): 7dc05d2

Upload 7 files

Files changed (7)
  1. decoder.py +345 -0
  2. decoder_base.py +249 -0
  3. demo_interface.py +18 -0
  4. inference.py +48 -0
  5. model-best.pt +3 -0
  6. p225_007_mic1.npy +3 -0
  7. requirements.txt +85 -0
decoder.py ADDED
@@ -0,0 +1,345 @@
+ import math
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+
+ class CustomLSTM(nn.Module):
+     def __init__(self, input_sz, hidden_sz):
+         super().__init__()
+         self.input_sz = input_sz
+         self.hidden_size = hidden_sz
+         self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
+         self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
+         self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
+         self.init_weights()
+
+     def init_weights(self):
+         stdv = 1.0 / math.sqrt(self.hidden_size)
+         for weight in self.parameters():
+             weight.data.uniform_(-stdv, stdv)
+
+     def forward(self, x, init_states=None):
+         """Assumes x is of shape (batch, sequence, feature)"""
+         #print(type(x))
+         #print(x.shape)
+         bs, seq_sz, _ = x.size()
+         hidden_seq = []
+         if init_states is None:
+             h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
+                         torch.zeros(bs, self.hidden_size).to(x.device))
+         else:
+             h_t, c_t = init_states
+
+         HS = self.hidden_size
+         for t in range(seq_sz):
+             x_t = x[:, t, :]
+             # batch the computations into a single matrix multiplication
+             gates = x_t @ self.W + h_t @ self.U + self.bias
+             i_t, f_t, g_t, o_t = (
+                 torch.sigmoid(gates[:, :HS]),      # input
+                 torch.sigmoid(gates[:, HS:HS*2]),  # forget
+                 torch.tanh(gates[:, HS*2:HS*3]),
+                 torch.sigmoid(gates[:, HS*3:]),    # output
+             )
+             c_t = f_t * c_t + i_t * g_t
+             h_t = o_t * torch.tanh(c_t)
+             hidden_seq.append(h_t.unsqueeze(0))
+         hidden_seq = torch.cat(hidden_seq, dim=0)
+         # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
+         hidden_seq = hidden_seq.transpose(0, 1).contiguous()
+         return hidden_seq, (h_t, c_t)
+
+ hparams = {
+     'n_mel_channels': 128,  # From LogMelSpectrogram
+     'postnet_embedding_dim': 512,  # Common choice, adjust as needed
+     'postnet_kernel_size': 5,  # Common choice, adjust as needed
+     'postnet_n_convolutions': 5,  # Typical number of Postnet convolutions
+ }
+
+ class ConvNorm(torch.nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+                  padding=None, dilation=1, bias=True, w_init_gain='linear'):
+         super(ConvNorm, self).__init__()
+         if padding is None:
+             assert kernel_size % 2 == 1
+             padding = int(dilation * (kernel_size - 1) / 2)
+
+         self.conv = torch.nn.Conv1d(in_channels, out_channels,
+                                     kernel_size=kernel_size, stride=stride,
+                                     padding=padding, dilation=dilation,
+                                     bias=bias)
+
+         torch.nn.init.xavier_uniform_(
+             self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+     def forward(self, signal):
+         conv_signal = self.conv(signal)
+         return conv_signal
+
+ URLS = {
+     "hubert-discrete": "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-discrete-d49e1c77.pt",
+     "hubert-soft": "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-soft-0321fd7e.pt",
+ }
+
+
+ class AcousticModel(nn.Module):
+     def __init__(self, discrete: bool = False, upsample: bool = True, use_custom_lstm=False):
+         super().__init__()
+         # self.spk_projection = nn.Linear(512+512, 512)
+         self.encoder = Encoder(discrete, upsample)
+         self.decoder = Decoder(use_custom_lstm=use_custom_lstm)
+         self.postnet = Postnet(hparams)  # hparams is defined above; explicit parameters can be passed instead
+
+     def forward(self, x: torch.Tensor, spk_embs, mels: torch.Tensor) -> torch.Tensor:
+         x = self.encoder(x)
+         exp_spk_embs = spk_embs.unsqueeze(1).expand(-1, x.size(1), -1)
+         concat_x = torch.cat([x, exp_spk_embs], dim=-1)
+         # x = self.spk_projection(concat_x)
+         output = self.decoder(concat_x, mels)
+         postnet_output = self.postnet(output) + output
+         return postnet_output
+
+     #def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
+     #    x = self.encoder(x)
+     #    return self.decoder(x, mels)
+
+     def forward_test(self, x, spk_embs, mels):
+         print('x shape', x.shape)
+         print('se shape', spk_embs.shape)
+         print('mels shape', mels.shape)
+         x = self.encoder(x)
+         print('x_enc shape', x.shape)
+         return
+
+     @torch.inference_mode()
+     def generate(self, x: torch.Tensor, spk_embs) -> torch.Tensor:
+         x = self.encoder(x)
+         exp_spk_embs = spk_embs.unsqueeze(1).expand(-1, x.size(1), -1)
+         concat_x = torch.cat([x, exp_spk_embs], dim=-1)
+         # x = self.spk_projection(concat_x)
+         mels = self.decoder.generate(concat_x)
+         postnet_mels = self.postnet(mels) + mels
+         return postnet_mels
+
+
+ class Encoder(nn.Module):
+     def __init__(self, discrete: bool = False, upsample: bool = True):
+         super().__init__()
+         self.embedding = nn.Embedding(100 + 1, 256) if discrete else None
+         self.prenet = PreNet(256, 256, 256)
+         self.convs = nn.Sequential(
+             nn.Conv1d(256, 512, 5, 1, 2),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.InstanceNorm1d(512),
+             nn.ConvTranspose1d(512, 512, 4, 2, 1) if upsample else nn.Identity(),
+             nn.Dropout(0.3),
+             nn.Conv1d(512, 512, 5, 1, 2),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.InstanceNorm1d(512),
+             nn.Conv1d(512, 512, 5, 1, 2),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.InstanceNorm1d(512),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.embedding is not None:
+             x = self.embedding(x)
+         x = self.prenet(x)
+         x = self.convs(x.transpose(1, 2))
+         return x.transpose(1, 2)
+
+
+ class Decoder(nn.Module):
+     def __init__(self, use_custom_lstm=False):
+         super().__init__()
+         self.use_custom_lstm = use_custom_lstm
+         self.prenet = PreNet(128, 256, 256)
+         if use_custom_lstm:
+             self.lstm1 = CustomLSTM(1024 + 256, 1024)
+             self.lstm2 = CustomLSTM(1024, 1024)
+             self.lstm3 = CustomLSTM(1024, 1024)
+         else:
+             self.lstm1 = nn.LSTM(1024 + 256, 1024)
+             self.lstm2 = nn.LSTM(1024, 1024)
+             self.lstm3 = nn.LSTM(1024, 1024)
+         self.proj = nn.Linear(1024, 128, bias=False)
+         self.dropout = nn.Dropout(0.3)
+
+     def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
+         mels = self.prenet(mels)
+         x, _ = self.lstm1(torch.cat((x, mels), dim=-1))
+         x = self.dropout(x)
+         res = x
+         x, _ = self.lstm2(x)
+         x = self.dropout(x)
+         x = res + x
+         res = x
+         x, _ = self.lstm3(x)
+         x = self.dropout(x)
+         x = res + x
+         return self.proj(x)
+
+     @torch.inference_mode()
+     def generate(self, xs: torch.Tensor) -> torch.Tensor:
+         m = torch.zeros(xs.size(0), 128, device=xs.device)
+         if self.use_custom_lstm:
+             h1 = torch.zeros(xs.size(0), 1024, device=xs.device)
+             c1 = torch.zeros(xs.size(0), 1024, device=xs.device)
+             h2 = torch.zeros(xs.size(0), 1024, device=xs.device)
+             c2 = torch.zeros(xs.size(0), 1024, device=xs.device)
+             h3 = torch.zeros(xs.size(0), 1024, device=xs.device)
+             c3 = torch.zeros(xs.size(0), 1024, device=xs.device)
+         else:
+             h1 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
+             c1 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
+             h2 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
+             c2 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
+             h3 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
+             c3 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
+
+         mel = []
+         for x in torch.unbind(xs, dim=1):
+             m = self.prenet(m)
+             x = torch.cat((x, m), dim=1).unsqueeze(1)
+             x1, (h1, c1) = self.lstm1(x, (h1, c1))
+             x2, (h2, c2) = self.lstm2(x1, (h2, c2))
+             x = x1 + x2
+             x3, (h3, c3) = self.lstm3(x, (h3, c3))
+             x = x + x3
+             m = self.proj(x).squeeze(1)
+             mel.append(m)
+         return torch.stack(mel, dim=1)
+
+
+ class PreNet(nn.Module):
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int,
+         output_size: int,
+         dropout: float = 0.5,
+     ):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(input_size, hidden_size),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_size, output_size),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.net(x)
+
+
+ def _acoustic(
+     name: str,
+     discrete: bool,
+     upsample: bool,
+     pretrained: bool = True,
+     progress: bool = True,
+ ) -> AcousticModel:
+     acoustic = AcousticModel(discrete, upsample)
+     if pretrained:
+         checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress)
+         consume_prefix_in_state_dict_if_present(checkpoint["acoustic-model"], "module.")
+         acoustic.load_state_dict(checkpoint["acoustic-model"])
+         acoustic.eval()
+     return acoustic
+
+
+ def hubert_discrete(
+     pretrained: bool = True,
+     progress: bool = True,
+ ) -> AcousticModel:
+     r"""HuBERT-Discrete acoustic model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
+     Args:
+         pretrained (bool): load pretrained weights into the model
+         progress (bool): show progress bar when downloading model
+     """
+     return _acoustic(
+         "hubert-discrete",
+         discrete=True,
+         upsample=True,
+         pretrained=pretrained,
+         progress=progress,
+     )
+
+
+ def hubert_soft(
+     pretrained: bool = True,
+     progress: bool = True,
+ ) -> AcousticModel:
+     r"""HuBERT-Soft acoustic model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
+     Args:
+         pretrained (bool): load pretrained weights into the model
+         progress (bool): show progress bar when downloading model
+     """
+     return _acoustic(
+         "hubert-soft",
+         discrete=False,
+         upsample=True,
+         pretrained=pretrained,
+         progress=progress,
+     )
+
+
+ class Postnet(nn.Module):
+     def __init__(self, hparams):
+         super(Postnet, self).__init__()
+         self.convolutions = nn.ModuleList()
+
+         self.convolutions.append(
+             nn.Sequential(
+                 ConvNorm(in_channels=hparams['n_mel_channels'],
+                          out_channels=hparams['postnet_embedding_dim'],
+                          kernel_size=hparams['postnet_kernel_size'], stride=1,
+                          padding=int((hparams['postnet_kernel_size'] - 1) / 2),  # Dynamic padding
+                          dilation=1, bias=True, w_init_gain='tanh'),
+                 nn.BatchNorm1d(hparams['postnet_embedding_dim'])
+             )
+         )
+
+         for i in range(1, hparams['postnet_n_convolutions'] - 1):
+             self.convolutions.append(
+                 nn.Sequential(
+                     ConvNorm(hparams['postnet_embedding_dim'],
+                              hparams['postnet_embedding_dim'],
+                              kernel_size=hparams['postnet_kernel_size'], stride=1,
+                              padding=int((hparams['postnet_kernel_size'] - 1) / 2),  # Dynamic padding
+                              dilation=1, w_init_gain='tanh'),
+                     nn.BatchNorm1d(hparams['postnet_embedding_dim'])
+                 )
+             )
+
+         self.convolutions.append(
+             nn.Sequential(
+                 ConvNorm(hparams['postnet_embedding_dim'], hparams['n_mel_channels'],
+                          kernel_size=hparams['postnet_kernel_size'], stride=1,
+                          padding=int((hparams['postnet_kernel_size'] - 1) / 2),  # Dynamic padding
+                          dilation=1, w_init_gain='linear'),
+                 nn.BatchNorm1d(hparams['n_mel_channels'])
+             )
+         )
+
+     def forward(self, x):
+         #print(f"Input shape to Postnet: {x.shape}")
+         x = x.transpose(1, 2)
+         for i, conv in enumerate(self.convolutions[:-1]):
+             x = conv(x)
+             #print(f"Shape after Convolution {i+1}: {x.shape}")
+             x = torch.tanh(x)
+             x = F.dropout(x, 0.5, self.training)
+
+         # Last layer
+         x = self.convolutions[-1](x)
+         #print(f"Shape after last Convolution: {x.shape}")
+         x = F.dropout(x, 0.5, self.training)
+         x = x.transpose(1, 2)
+
+         return x
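For orientation, below is a minimal smoke-test sketch of how this AcousticModel is driven. The 256-dim soft units, 512-dim speaker embedding, and 2x temporal upsampling are read off the layer definitions above rather than documented anywhere, so treat the shapes as assumptions.

import torch
from decoder import AcousticModel

model = AcousticModel().eval()            # soft units (discrete=False), with upsampling and Postnet
units = torch.randn(1, 50, 256)           # (batch, frames, unit_dim), e.g. HuBERT-Soft features
spk_emb = torch.randn(1, 512)             # (batch, speaker_embedding_dim)
with torch.inference_mode():
    mel = model.generate(units, spk_emb)  # autoregressive decoding plus the Postnet residual
print(mel.shape)                          # torch.Size([1, 100, 128]): frames doubled by ConvTranspose1d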
decoder_base.py ADDED
@@ -0,0 +1,249 @@
+ import math
+ import torch
+ import torch.nn as nn
+ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+
+ URLS = {
+     "hubert-discrete": "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-discrete-d49e1c77.pt",
+     "hubert-soft": "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-soft-0321fd7e.pt",
+ }
+
+ class CustomLSTM(nn.Module):
+     def __init__(self, input_sz, hidden_sz):
+         super().__init__()
+         self.input_sz = input_sz
+         self.hidden_size = hidden_sz
+         self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
+         self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
+         self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
+         self.init_weights()
+
+     def init_weights(self):
+         stdv = 1.0 / math.sqrt(self.hidden_size)
+         for weight in self.parameters():
+             weight.data.uniform_(-stdv, stdv)
+
+     def forward(self, x, init_states=None):
+         """Assumes x is of shape (batch, sequence, feature)"""
+         #print(type(x))
+         #print(x.shape)
+         bs, seq_sz, _ = x.size()
+         hidden_seq = []
+         if init_states is None:
+             h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
+                         torch.zeros(bs, self.hidden_size).to(x.device))
+         else:
+             h_t, c_t = init_states
+
+         HS = self.hidden_size
+         for t in range(seq_sz):
+             x_t = x[:, t, :]
+             # batch the computations into a single matrix multiplication
+             gates = x_t @ self.W + h_t @ self.U + self.bias
+             i_t, f_t, g_t, o_t = (
+                 torch.sigmoid(gates[:, :HS]),      # input
+                 torch.sigmoid(gates[:, HS:HS*2]),  # forget
+                 torch.tanh(gates[:, HS*2:HS*3]),
+                 torch.sigmoid(gates[:, HS*3:]),    # output
+             )
+             c_t = f_t * c_t + i_t * g_t
+             h_t = o_t * torch.tanh(c_t)
+             hidden_seq.append(h_t.unsqueeze(0))
+         hidden_seq = torch.cat(hidden_seq, dim=0)
+         # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)
+         hidden_seq = hidden_seq.transpose(0, 1).contiguous()
+         return hidden_seq, (h_t, c_t)
+
+ class AcousticModel(nn.Module):
+     def __init__(self, discrete: bool = False, upsample: bool = True, use_custom_lstm=False):
+         super().__init__()
+         # self.spk_projection = nn.Linear(512+512, 512)
+         self.encoder = Encoder(discrete, upsample)
+         self.decoder = Decoder(use_custom_lstm=use_custom_lstm)
+
+     def forward(self, x: torch.Tensor, spk_embs, mels: torch.Tensor) -> torch.Tensor:
+         x = self.encoder(x)
+         exp_spk_embs = spk_embs.unsqueeze(1).expand(-1, x.size(1), -1)
+         concat_x = torch.cat([x, exp_spk_embs], dim=-1)
+         # x = self.spk_projection(concat_x)
+         return self.decoder(concat_x, mels)
+
+     #def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
+     #    x = self.encoder(x)
+     #    return self.decoder(x, mels)
+
+     def forward_test(self, x, spk_embs, mels):
+         print('x shape', x.shape)
+         print('se shape', spk_embs.shape)
+         print('mels shape', mels.shape)
+         x = self.encoder(x)
+         print('x_enc shape', x.shape)
+         return
+
+     @torch.inference_mode()
+     def generate(self, x: torch.Tensor, spk_embs) -> torch.Tensor:
+         x = self.encoder(x)
+         exp_spk_embs = spk_embs.unsqueeze(1).expand(-1, x.size(1), -1)
+         concat_x = torch.cat([x, exp_spk_embs], dim=-1)
+         # x = self.spk_projection(concat_x)
+         return self.decoder.generate(concat_x)
+
+
+ class Encoder(nn.Module):
+     def __init__(self, discrete: bool = False, upsample: bool = True):
+         super().__init__()
+         self.embedding = nn.Embedding(100 + 1, 256) if discrete else None
+         self.prenet = PreNet(256, 256, 256)
+         self.convs = nn.Sequential(
+             nn.Conv1d(256, 512, 5, 1, 2),
+             nn.ReLU(),
+             nn.InstanceNorm1d(512),
+             nn.ConvTranspose1d(512, 512, 4, 2, 1) if upsample else nn.Identity(),
+             nn.Conv1d(512, 512, 5, 1, 2),
+             nn.ReLU(),
+             nn.InstanceNorm1d(512),
+             nn.Conv1d(512, 512, 5, 1, 2),
+             nn.ReLU(),
+             nn.InstanceNorm1d(512),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         if self.embedding is not None:
+             x = self.embedding(x)
+         x = self.prenet(x)
+         x = self.convs(x.transpose(1, 2))
+         return x.transpose(1, 2)
+
+
+ class Decoder(nn.Module):
+     def __init__(self, use_custom_lstm=False):
+         super().__init__()
+         self.use_custom_lstm = use_custom_lstm
+         self.prenet = PreNet(128, 256, 256)
+         if use_custom_lstm:
+             self.lstm1 = CustomLSTM(1024 + 256, 768)
+             self.lstm2 = CustomLSTM(768, 768)
+             self.lstm3 = CustomLSTM(768, 768)
+         else:
+             self.lstm1 = nn.LSTM(1024 + 256, 768)
+             self.lstm2 = nn.LSTM(768, 768)
+             self.lstm3 = nn.LSTM(768, 768)
+         self.proj = nn.Linear(768, 128, bias=False)
+
+     def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
+         mels = self.prenet(mels)
+         x, _ = self.lstm1(torch.cat((x, mels), dim=-1))
+         res = x
+         x, _ = self.lstm2(x)
+         x = res + x
+         res = x
+         x, _ = self.lstm3(x)
+         x = res + x
+         return self.proj(x)
+
+     @torch.inference_mode()
+     def generate(self, xs: torch.Tensor) -> torch.Tensor:
+         m = torch.zeros(xs.size(0), 128, device=xs.device)
+         if not self.use_custom_lstm:
+             h1 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+             c1 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+             h2 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+             c2 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+             h3 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+             c3 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+         else:
+             h1 = torch.zeros(xs.size(0), 768, device=xs.device)
+             c1 = torch.zeros(xs.size(0), 768, device=xs.device)
+             h2 = torch.zeros(xs.size(0), 768, device=xs.device)
+             c2 = torch.zeros(xs.size(0), 768, device=xs.device)
+             h3 = torch.zeros(xs.size(0), 768, device=xs.device)
+             c3 = torch.zeros(xs.size(0), 768, device=xs.device)
+
+         mel = []
+         for x in torch.unbind(xs, dim=1):
+             m = self.prenet(m)
+             x = torch.cat((x, m), dim=1).unsqueeze(1)
+             x1, (h1, c1) = self.lstm1(x, (h1, c1))
+             x2, (h2, c2) = self.lstm2(x1, (h2, c2))
+             x = x1 + x2
+             x3, (h3, c3) = self.lstm3(x, (h3, c3))
+             x = x + x3
+             m = self.proj(x).squeeze(1)
+             mel.append(m)
+         return torch.stack(mel, dim=1)
+
+
+ class PreNet(nn.Module):
+     def __init__(
+         self,
+         input_size: int,
+         hidden_size: int,
+         output_size: int,
+         dropout: float = 0.5,
+     ):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Linear(input_size, hidden_size),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+             nn.Linear(hidden_size, output_size),
+             nn.ReLU(),
+             nn.Dropout(dropout),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.net(x)
+
+
+ def _acoustic(
+     name: str,
+     discrete: bool,
+     upsample: bool,
+     pretrained: bool = True,
+     progress: bool = True,
+ ) -> AcousticModel:
+     acoustic = AcousticModel(discrete, upsample)
+     if pretrained:
+         checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress)
+         consume_prefix_in_state_dict_if_present(checkpoint["acoustic-model"], "module.")
+         acoustic.load_state_dict(checkpoint["acoustic-model"])
+         acoustic.eval()
+     return acoustic
+
+
+ def hubert_discrete(
+     pretrained: bool = True,
+     progress: bool = True,
+ ) -> AcousticModel:
+     r"""HuBERT-Discrete acoustic model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
+     Args:
+         pretrained (bool): load pretrained weights into the model
+         progress (bool): show progress bar when downloading model
+     """
+     return _acoustic(
+         "hubert-discrete",
+         discrete=True,
+         upsample=True,
+         pretrained=pretrained,
+         progress=progress,
+     )
+
+
+ def hubert_soft(
+     pretrained: bool = True,
+     progress: bool = True,
+ ) -> AcousticModel:
+     r"""HuBERT-Soft acoustic model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
+     Args:
+         pretrained (bool): load pretrained weights into the model
+         progress (bool): show progress bar when downloading model
+     """
+     return _acoustic(
+         "hubert-soft",
+         discrete=False,
+         upsample=True,
+         pretrained=pretrained,
+         progress=progress,
+     )
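decoder_base.AcousticModel is the lighter variant used at inference time in this repo (768-dim LSTMs, no Postnet, no extra dropout in the encoder or decoder). A minimal loading sketch, assuming the "acoustic-model" checkpoint key that inference.py uses and a CPU-only environment:

import torch
from decoder_base import AcousticModel

model = AcousticModel()                                 # use_custom_lstm=False by default
ckpt = torch.load("model-best.pt", map_location="cpu")  # checkpoint uploaded in this commit
model.load_state_dict(ckpt["acoustic-model"])
model.eval()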
demo_interface.py ADDED
@@ -0,0 +1,18 @@
+ import gradio as gr
+ from inference import InferencePipeline
+
+ i = InferencePipeline()
+
+ # Gradio 4.x passes components directly; the old gr.inputs / gr.outputs namespaces were removed.
+ demo = gr.Interface(
+     fn=i.voice_conversion,
+     inputs=gr.Audio(sources=["microphone", "upload"], type="filepath", label="Record or upload your voice"),
+     outputs=gr.Audio(label="Converted Voice"),
+     title="Voice Conversion Demo",
+     description="Voice Conversion: Transform the input voice to a target voice.",
+     allow_flagging="never",
+ )
+
+ if __name__ == "__main__":
+     demo.launch()
inference.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ import torchaudio
+ import numpy as np
+ from decoder_base import AcousticModel
+
+ class InferencePipeline():
+     def __init__(self):
+         # download hubert content encoder
+         self.hubert = torch.hub.load("bshall/hubert:main", "hubert_soft", trust_repo=True)  # .cuda()
+
+         # initialize decoder with checkpoint
+         ckpts_path = 'model-best.pt'
+         self.model = AcousticModel()
+         cp = torch.load(ckpts_path, map_location=torch.device('cpu'))
+         self.model.load_state_dict(cp['acoustic-model'])
+
+         # download vocoder
+         self.hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft", trust_repo=True, map_location=torch.device('cpu'))
+
+         # load target speaker embedding
+         self.trg_spk_emb = np.load('p225_007_mic1.npy')
+         self.trg_spk_emb = torch.from_numpy(self.trg_spk_emb)
+         self.trg_spk_emb = self.trg_spk_emb.unsqueeze(0)  # .cuda()
+
+     def voice_conversion(self, audio_file_path):
+         # load the source audio and resample to the 16 kHz expected by HuBERT
+         source, sr = torchaudio.load(audio_file_path)
+         source = torchaudio.functional.resample(source, sr, 16000)
+         source = source.unsqueeze(0)  # .cuda()
+
+         # run inference
+         self.model.eval()
+         with torch.inference_mode():
+             # Extract speech units
+             units = self.hubert.units(source)
+             # Generate target spectrogram
+             mel = self.model.generate(units, self.trg_spk_emb).transpose(1, 2)
+             # Generate audio waveform
+             target = self.hifigan(mel)
+
+         # `target` is (batch, channels, samples); drop the batch dim and save the waveform
+         output_audio_path = "output.wav"
+         torchaudio.save(output_audio_path, target.squeeze(0).cpu(), sample_rate=16000)
+
+         return output_audio_path
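A hypothetical local invocation of the pipeline above (any mono WAV readable by torchaudio will do; "sample.wav" is a placeholder path, and conversion targets the single p225 speaker embedding shipped in this commit):

from inference import InferencePipeline

pipeline = InferencePipeline()
converted = pipeline.voice_conversion("sample.wav")
print(converted)  # "output.wav", written next to the script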
model-best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:691a5a2e6d878f51c7451db9a294c359a47ffa32ef4d0e8668ababddd087cf4d
+ size 244872425
p225_007_mic1.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:77ee23b151c88540987be1f6eb91123937723846b47847510b1d340705ab7394
+ size 2176
requirements.txt ADDED
@@ -0,0 +1,85 @@
+ aiofiles==23.2.1
+ altair==5.2.0
+ annotated-types==0.6.0
+ anyio==4.2.0
+ attrs==23.2.0
+ certifi==2024.2.2
+ cffi==1.16.0
+ charset-normalizer==3.3.2
+ click==8.1.7
+ colorama==0.4.6
+ contourpy==1.2.0
+ cycler==0.12.1
+ exceptiongroup==1.2.0
+ fastapi==0.109.2
+ ffmpy==0.3.2
+ filelock==3.13.1
+ fonttools==4.48.1
+ fsspec==2023.10.0
+ gmpy2==2.1.2
+ gradio==4.18.0
+ gradio_client==0.10.0
+ h11==0.14.0
+ httpcore==1.0.2
+ httpx==0.26.0
+ huggingface-hub==0.20.3
+ idna==3.6
+ importlib-resources==6.1.1
+ Jinja2==3.1.3
+ joblib==1.2.0
+ jsonschema==4.21.1
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.5
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.3
+ matplotlib==3.8.2
+ mdurl==0.1.2
+ mkl-fft==1.3.8
+ mkl-random==1.2.4
+ mkl-service==2.4.0
+ mpmath==1.3.0
+ networkx==3.1
+ numpy==1.26.3
+ orjson==3.9.13
+ packaging==23.2
+ pandas==2.2.0
+ pillow==10.2.0
+ pip==23.3.1
+ pycparser==2.21
+ pydantic==2.6.1
+ pydantic_core==2.16.2
+ pydub==0.25.1
+ Pygments==2.17.2
+ pyparsing==3.1.1
+ python-dateutil==2.8.2
+ python-multipart==0.0.9
+ pytz==2024.1
+ PyYAML==6.0.1
+ referencing==0.33.0
+ requests==2.31.0
+ rich==13.7.0
+ rpds-py==0.17.1
+ ruff==0.2.1
+ scikit-learn==1.2.2
+ scipy==1.11.4
+ semantic-version==2.10.0
+ setuptools==68.2.2
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.0
+ starlette==0.36.3
+ sympy==1.12
+ threadpoolctl==2.2.0
+ tomlkit==0.12.0
+ toolz==0.12.1
+ torch==2.2.0
+ torchaudio==2.2.0
+ tqdm==4.66.2
+ typer==0.9.0
+ typing_extensions==4.9.0
+ tzdata==2024.1
+ urllib3==2.2.0
+ uvicorn==0.27.1
+ websockets==11.0.3
+ wheel==0.41.2
+ zipp==3.17.0