uzdzn committed on
Commit
33079ce
1 Parent(s): 47adc7e

Delete decoder.py

Files changed (1)
  decoder.py +0 -345
decoder.py DELETED
@@ -1,345 +0,0 @@
- import math
- import torch
- import torch.nn as nn
- import torch.nn.functional as F  # used by Postnet.forward (F.dropout)
- from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
-
- class CustomLSTM(nn.Module):
-     def __init__(self, input_sz, hidden_sz):
-         super().__init__()
-         self.input_sz = input_sz
-         self.hidden_size = hidden_sz
-         self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
-         self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))
-         self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
-         self.init_weights()
-
-     def init_weights(self):
-         stdv = 1.0 / math.sqrt(self.hidden_size)
-         for weight in self.parameters():
-             weight.data.uniform_(-stdv, stdv)
-
-     def forward(self, x, init_states=None):
-         """Assumes x is of shape (batch, sequence, feature)."""
-         bs, seq_sz, _ = x.size()
-         hidden_seq = []
-         if init_states is None:
-             h_t, c_t = (torch.zeros(bs, self.hidden_size).to(x.device),
-                         torch.zeros(bs, self.hidden_size).to(x.device))
-         else:
-             h_t, c_t = init_states
-
-         HS = self.hidden_size
-         for t in range(seq_sz):
-             x_t = x[:, t, :]
-             # batch the four gate computations into a single matrix multiplication
-             gates = x_t @ self.W + h_t @ self.U + self.bias
-             i_t, f_t, g_t, o_t = (
-                 torch.sigmoid(gates[:, :HS]),         # input gate
-                 torch.sigmoid(gates[:, HS:HS * 2]),   # forget gate
-                 torch.tanh(gates[:, HS * 2:HS * 3]),  # cell candidate
-                 torch.sigmoid(gates[:, HS * 3:]),     # output gate
-             )
-             c_t = f_t * c_t + i_t * g_t
-             h_t = o_t * torch.tanh(c_t)
-             hidden_seq.append(h_t.unsqueeze(0))
-         hidden_seq = torch.cat(hidden_seq, dim=0)
-         # reshape from (sequence, batch, feature) to (batch, sequence, feature)
-         hidden_seq = hidden_seq.transpose(0, 1).contiguous()
-         return hidden_seq, (h_t, c_t)
-
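- # Illustrative usage sketch (assumed sizes, untrained weights): CustomLSTM
- # follows the (batch, sequence, feature) layout of nn.LSTM(batch_first=True):
- #     lstm = CustomLSTM(input_sz=8, hidden_sz=16)
- #     out, (h_t, c_t) = lstm(torch.randn(4, 10, 8))  # out: (4, 10, 16)
-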
- hparams = {
-     'n_mel_channels': 128,         # From LogMelSpectrogram
-     'postnet_embedding_dim': 512,  # Common choice, adjust as needed
-     'postnet_kernel_size': 5,      # Common choice, adjust as needed
-     'postnet_n_convolutions': 5,   # Typical number of Postnet convolutions
- }
-
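- # hparams is consumed by Postnet (defined at the bottom of this file) via
- # AcousticModel.__init__ below.
-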
- class ConvNorm(torch.nn.Module):
-     def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
-                  padding=None, dilation=1, bias=True, w_init_gain='linear'):
-         super(ConvNorm, self).__init__()
-         if padding is None:
-             assert kernel_size % 2 == 1
-             padding = int(dilation * (kernel_size - 1) / 2)
-
-         self.conv = torch.nn.Conv1d(in_channels, out_channels,
-                                     kernel_size=kernel_size, stride=stride,
-                                     padding=padding, dilation=dilation,
-                                     bias=bias)
-
-         torch.nn.init.xavier_uniform_(
-             self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
-
-     def forward(self, signal):
-         conv_signal = self.conv(signal)
-         return conv_signal
-
- URLS = {
-     "hubert-discrete": "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-discrete-d49e1c77.pt",
-     "hubert-soft": "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-soft-0321fd7e.pt",
- }
-
-
- class AcousticModel(nn.Module):
-     def __init__(self, discrete: bool = False, upsample: bool = True, use_custom_lstm=False):
-         super().__init__()
-         # self.spk_projection = nn.Linear(512 + 512, 512)
-         self.encoder = Encoder(discrete, upsample)
-         self.decoder = Decoder(use_custom_lstm=use_custom_lstm)
-         self.postnet = Postnet(hparams)  # hparams is defined above; pass explicit parameters if preferred
-
-     def forward(self, x: torch.Tensor, spk_embs, mels: torch.Tensor) -> torch.Tensor:
-         x = self.encoder(x)
-         # broadcast the per-utterance speaker embedding across all time steps
-         exp_spk_embs = spk_embs.unsqueeze(1).expand(-1, x.size(1), -1)
-         concat_x = torch.cat([x, exp_spk_embs], dim=-1)
-         # x = self.spk_projection(concat_x)
-         output = self.decoder(concat_x, mels)
-         # the Postnet predicts a residual correction on top of the coarse mels
-         postnet_output = self.postnet(output) + output
-         return postnet_output
-
-     # def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
-     #     x = self.encoder(x)
-     #     return self.decoder(x, mels)
-
-     def forward_test(self, x, spk_embs, mels):
-         print('x shape', x.shape)
-         print('se shape', spk_embs.shape)
-         print('mels shape', mels.shape)
-         x = self.encoder(x)
-         print('x_enc shape', x.shape)
-         return
-
-     @torch.inference_mode()
-     def generate(self, x: torch.Tensor, spk_embs) -> torch.Tensor:
-         x = self.encoder(x)
-         exp_spk_embs = spk_embs.unsqueeze(1).expand(-1, x.size(1), -1)
-         concat_x = torch.cat([x, exp_spk_embs], dim=-1)
-         # x = self.spk_projection(concat_x)
-         mels = self.decoder.generate(concat_x)
-         postnet_mels = self.postnet(mels) + mels
-         return postnet_mels
-
-
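- # Illustrative shape walk-through (sizes inferred from the layers above, not
- # from any released config): the speaker embedding must be 512-d so that the
- # encoder output (512) concatenated with it matches the decoder's 1024 input.
- #     model = AcousticModel(discrete=False, upsample=True)
- #     units = torch.randn(1, 50, 256)   # soft speech units
- #     spk = torch.randn(1, 512)         # per-utterance speaker embedding
- #     mels = torch.randn(1, 100, 128)   # upsampling doubles 50 -> 100 frames
- #     out = model(units, spk, mels)     # -> (1, 100, 128)
-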
- class Encoder(nn.Module):
-     def __init__(self, discrete: bool = False, upsample: bool = True):
-         super().__init__()
-         self.embedding = nn.Embedding(100 + 1, 256) if discrete else None
-         self.prenet = PreNet(256, 256, 256)
-         self.convs = nn.Sequential(
-             nn.Conv1d(256, 512, 5, 1, 2),
-             nn.ReLU(),
-             nn.Dropout(0.3),
-             nn.InstanceNorm1d(512),
-             # the transposed convolution doubles the frame rate when upsampling
-             nn.ConvTranspose1d(512, 512, 4, 2, 1) if upsample else nn.Identity(),
-             nn.Dropout(0.3),
-             nn.Conv1d(512, 512, 5, 1, 2),
-             nn.ReLU(),
-             nn.Dropout(0.3),
-             nn.InstanceNorm1d(512),
-             nn.Conv1d(512, 512, 5, 1, 2),
-             nn.ReLU(),
-             nn.Dropout(0.3),
-             nn.InstanceNorm1d(512),
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         if self.embedding is not None:
-             x = self.embedding(x)
-         x = self.prenet(x)
-         # Conv1d expects (batch, channels, time)
-         x = self.convs(x.transpose(1, 2))
-         return x.transpose(1, 2)
-
-
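- # Shape sketch (assumed input sizes): with upsample=True the ConvTranspose1d
- # doubles the frame rate, so (batch, T, 256) units map to (batch, 2 * T, 512):
- #     enc = Encoder(discrete=False, upsample=True)
- #     enc(torch.randn(3, 50, 256)).shape  # torch.Size([3, 100, 512])
-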
- class Decoder(nn.Module):
-     def __init__(self, use_custom_lstm=False):
-         super().__init__()
-         self.use_custom_lstm = use_custom_lstm
-         self.prenet = PreNet(128, 256, 256)
-         if use_custom_lstm:
-             self.lstm1 = CustomLSTM(1024 + 256, 1024)
-             self.lstm2 = CustomLSTM(1024, 1024)
-             self.lstm3 = CustomLSTM(1024, 1024)
-         else:
-             # batch_first=True so nn.LSTM matches CustomLSTM's
-             # (batch, sequence, feature) layout and the hidden-state
-             # shapes used in generate()
-             self.lstm1 = nn.LSTM(1024 + 256, 1024, batch_first=True)
-             self.lstm2 = nn.LSTM(1024, 1024, batch_first=True)
-             self.lstm3 = nn.LSTM(1024, 1024, batch_first=True)
-         self.proj = nn.Linear(1024, 128, bias=False)
-         self.dropout = nn.Dropout(0.3)
-
-     def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
-         mels = self.prenet(mels)
-         x, _ = self.lstm1(torch.cat((x, mels), dim=-1))
-         x = self.dropout(x)
-         res = x
-         x, _ = self.lstm2(x)
-         x = self.dropout(x)
-         x = res + x
-         res = x
-         x, _ = self.lstm3(x)
-         x = self.dropout(x)
-         x = res + x
-         return self.proj(x)
-
-     @torch.inference_mode()
-     def generate(self, xs: torch.Tensor) -> torch.Tensor:
-         m = torch.zeros(xs.size(0), 128, device=xs.device)
-         if self.use_custom_lstm:
-             # CustomLSTM carries hidden state as (batch, hidden)
-             h1 = torch.zeros(xs.size(0), 1024, device=xs.device)
-             c1 = torch.zeros(xs.size(0), 1024, device=xs.device)
-             h2 = torch.zeros(xs.size(0), 1024, device=xs.device)
-             c2 = torch.zeros(xs.size(0), 1024, device=xs.device)
-             h3 = torch.zeros(xs.size(0), 1024, device=xs.device)
-             c3 = torch.zeros(xs.size(0), 1024, device=xs.device)
-         else:
-             # nn.LSTM carries hidden state as (num_layers, batch, hidden)
-             h1 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
-             c1 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
-             h2 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
-             c2 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
-             h3 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
-             c3 = torch.zeros(1, xs.size(0), 1024, device=xs.device)
-
-         # autoregressive loop: each predicted mel frame is fed back through
-         # the prenet as input for the next step
-         mel = []
-         for x in torch.unbind(xs, dim=1):
-             m = self.prenet(m)
-             x = torch.cat((x, m), dim=1).unsqueeze(1)
-             x1, (h1, c1) = self.lstm1(x, (h1, c1))
-             x2, (h2, c2) = self.lstm2(x1, (h2, c2))
-             x = x1 + x2
-             x3, (h3, c3) = self.lstm3(x, (h3, c3))
-             x = x + x3
-             m = self.proj(x).squeeze(1)
-             mel.append(m)
-         return torch.stack(mel, dim=1)
-
-
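- # Usage sketch (assumed shapes): forward is teacher-forced on ground-truth
- # mels, while generate runs autoregressively from a zero frame:
- #     dec = Decoder()
- #     out = dec(torch.randn(2, 100, 1024), torch.randn(2, 100, 128))  # (2, 100, 128)
- #     gen = dec.generate(torch.randn(2, 100, 1024))                   # (2, 100, 128)
-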
- class PreNet(nn.Module):
-     def __init__(
-         self,
-         input_size: int,
-         hidden_size: int,
-         output_size: int,
-         dropout: float = 0.5,
-     ):
-         super().__init__()
-         self.net = nn.Sequential(
-             nn.Linear(input_size, hidden_size),
-             nn.ReLU(),
-             nn.Dropout(dropout),
-             nn.Linear(hidden_size, output_size),
-             nn.ReLU(),
-             nn.Dropout(dropout),
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return self.net(x)
-
-
- def _acoustic(
-     name: str,
-     discrete: bool,
-     upsample: bool,
-     pretrained: bool = True,
-     progress: bool = True,
- ) -> AcousticModel:
-     acoustic = AcousticModel(discrete, upsample)
-     if pretrained:
-         checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress)
-         # strip the "module." prefix left by DistributedDataParallel checkpoints
-         consume_prefix_in_state_dict_if_present(checkpoint["acoustic-model"], "module.")
-         acoustic.load_state_dict(checkpoint["acoustic-model"])
-         acoustic.eval()
-     return acoustic
-
-
- def hubert_discrete(
-     pretrained: bool = True,
-     progress: bool = True,
- ) -> AcousticModel:
-     r"""HuBERT-Discrete acoustic model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
-
-     Args:
-         pretrained (bool): load pretrained weights into the model
-         progress (bool): show progress bar when downloading model
-     """
-     return _acoustic(
-         "hubert-discrete",
-         discrete=True,
-         upsample=True,
-         pretrained=pretrained,
-         progress=progress,
-     )
-
-
- def hubert_soft(
-     pretrained: bool = True,
-     progress: bool = True,
- ) -> AcousticModel:
-     r"""HuBERT-Soft acoustic model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
-
-     Args:
-         pretrained (bool): load pretrained weights into the model
-         progress (bool): show progress bar when downloading model
-     """
-     return _acoustic(
-         "hubert-soft",
-         discrete=False,
-         upsample=True,
-         pretrained=pretrained,
-         progress=progress,
-     )
-
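- # Loading sketch (hypothetical call): the released bshall checkpoints predate
- # the speaker-embedding and Postnet additions in this file, so pretrained=True
- # may not load cleanly into this modified architecture:
- #     acoustic = hubert_soft(pretrained=False)
- #     mels = acoustic.generate(units, spk_embs)
-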
- class Postnet(nn.Module):
-     def __init__(self, hparams):
-         super(Postnet, self).__init__()
-         self.convolutions = nn.ModuleList()
-
-         # first convolution: n_mel_channels -> postnet_embedding_dim
-         self.convolutions.append(
-             nn.Sequential(
-                 ConvNorm(in_channels=hparams['n_mel_channels'],
-                          out_channels=hparams['postnet_embedding_dim'],
-                          kernel_size=hparams['postnet_kernel_size'], stride=1,
-                          padding=int((hparams['postnet_kernel_size'] - 1) / 2),
-                          dilation=1, bias=True, w_init_gain='tanh'),
-                 nn.BatchNorm1d(hparams['postnet_embedding_dim'])
-             )
-         )
-
-         # middle convolutions keep the embedding width
-         for i in range(1, hparams['postnet_n_convolutions'] - 1):
-             self.convolutions.append(
-                 nn.Sequential(
-                     ConvNorm(hparams['postnet_embedding_dim'],
-                              hparams['postnet_embedding_dim'],
-                              kernel_size=hparams['postnet_kernel_size'], stride=1,
-                              padding=int((hparams['postnet_kernel_size'] - 1) / 2),
-                              dilation=1, w_init_gain='tanh'),
-                     nn.BatchNorm1d(hparams['postnet_embedding_dim'])
-                 )
-             )
-
-         # final convolution projects back to n_mel_channels
-         self.convolutions.append(
-             nn.Sequential(
-                 ConvNorm(hparams['postnet_embedding_dim'], hparams['n_mel_channels'],
-                          kernel_size=hparams['postnet_kernel_size'], stride=1,
-                          padding=int((hparams['postnet_kernel_size'] - 1) / 2),
-                          dilation=1, w_init_gain='linear'),
-                 nn.BatchNorm1d(hparams['n_mel_channels'])
-             )
-         )
-
-     def forward(self, x):
-         # Conv1d expects (batch, channels, time)
-         x = x.transpose(1, 2)
-         for conv in self.convolutions[:-1]:
-             x = torch.tanh(conv(x))
-             x = F.dropout(x, 0.5, self.training)
-
-         # last layer: linear activation
-         x = self.convolutions[-1](x)
-         x = F.dropout(x, 0.5, self.training)
-         x = x.transpose(1, 2)
-
-         return x