Plachta committed
Commit 3ec1f67 · verified · 1 Parent(s): ec4d5dc

Create rmvpe.py

Files changed (1):
  1. modules/rmvpe.py +600 -0
modules/rmvpe.py ADDED
@@ -0,0 +1,600 @@
from io import BytesIO
import os
from typing import List, Optional, Tuple
import numpy as np
import torch

import torch.nn as nn
import torch.nn.functional as F
from librosa.util import normalize, pad_center, tiny
from scipy.signal import get_window

import logging

logger = logging.getLogger(__name__)


class STFT(torch.nn.Module):
    def __init__(
        self, filter_length=1024, hop_length=512, win_length=None, window="hann"
    ):
        """
        This module implements an STFT using 1D convolution and 1D transpose convolutions.
        This is a bit tricky, and some configurations probably won't work, because
        matching sizes before and after the transform is hard for arbitrary
        overlap-add setups. Right now, this code is known to work when the hop
        length is half the filter length (50% overlap between frames).

        Keyword Arguments:
            filter_length {int} -- Length of filters used (default: {1024})
            hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
            win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
                equals the filter length). (default: {None})
            window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
                (default: {'hann'})
        """
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length if win_length else filter_length
        self.window = window
        self.forward_transform = None
        self.pad_amount = int(self.filter_length / 2)
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )
        forward_basis = torch.FloatTensor(fourier_basis)
        inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))

        assert filter_length >= self.win_length
        # get window and zero center pad it to filter_length
        fft_window = get_window(window, self.win_length, fftbins=True)
        fft_window = pad_center(fft_window, size=filter_length)
        fft_window = torch.from_numpy(fft_window).float()

        # window the bases
        forward_basis *= fft_window
        inverse_basis = (inverse_basis.T * fft_window).T

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())
        self.register_buffer("fft_window", fft_window.float())

    def transform(self, input_data, return_phase=False):
        """Take input data (audio) to the STFT domain.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)
        """
        input_data = F.pad(
            input_data,
            (self.pad_amount, self.pad_amount),
            mode="reflect",
        )
        forward_transform = input_data.unfold(
            1, self.filter_length, self.hop_length
        ).permute(0, 2, 1)
        forward_transform = torch.matmul(self.forward_basis, forward_transform)
        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]
        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        if return_phase:
            phase = torch.atan2(imag_part.data, real_part.data)
            return magnitude, phase
        else:
            return magnitude

    def inverse(self, magnitude, phase):
        """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
        by the ``transform`` function.

        Arguments:
            magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
                num_frequencies, num_frames)
            phase {tensor} -- Phase of STFT with shape (num_batch,
                num_frequencies, num_frames)

        Returns:
            inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        cat = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )
        fold = torch.nn.Fold(
            output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length),
            kernel_size=(1, self.filter_length),
            stride=(1, self.hop_length),
        )
        inverse_transform = torch.matmul(self.inverse_basis, cat)
        inverse_transform = fold(inverse_transform)[
            :, 0, 0, self.pad_amount : -self.pad_amount
        ]
        window_square_sum = (
            self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0)
        )
        window_square_sum = fold(window_square_sum)[
            :, 0, 0, self.pad_amount : -self.pad_amount
        ]
        inverse_transform /= window_square_sum
        return inverse_transform

    def forward(self, input_data):
        """Take input data (audio) to the STFT domain and then back to audio.

        Arguments:
            input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)

        Returns:
            reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
                shape (num_batch, num_samples)
        """
        self.magnitude, self.phase = self.transform(input_data, return_phase=True)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction

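
# Hedged usage sketch (not part of the original file): round-trip a batch of
# audio through the convolution-based STFT above at the 50%-overlap setting
# the docstring recommends.
def _stft_roundtrip_example():
    stft = STFT(filter_length=1024, hop_length=512, win_length=1024, window="hann")
    audio = torch.randn(2, 16000)  # (num_batch, num_samples)
    magnitude, phase = stft.transform(audio, return_phase=True)  # (2, 513, frames)
    return stft.inverse(magnitude, phase)  # approximates the input up to edge effects
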

from time import time as ttime


class BiGRU(nn.Module):
    def __init__(self, input_features, hidden_features, num_layers):
        super(BiGRU, self).__init__()
        self.gru = nn.GRU(
            input_features,
            hidden_features,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        # (batch, time, input_features) -> (batch, time, 2 * hidden_features)
        return self.gru(x)[0]


class ConvBlockRes(nn.Module):
    def __init__(self, in_channels, out_channels, momentum=0.01):
        super(ConvBlockRes, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        # self.shortcut: Optional[nn.Module] = None
        # A 1x1 projection is only needed when the channel count changes;
        # otherwise forward() uses the identity skip.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))

    def forward(self, x: torch.Tensor):
        if not hasattr(self, "shortcut"):
            return self.conv(x) + x
        else:
            return self.conv(x) + self.shortcut(x)


class Encoder(nn.Module):
    def __init__(
        self,
        in_channels,
        in_size,
        n_encoders,
        kernel_size,
        n_blocks,
        out_channels=16,
        momentum=0.01,
    ):
        super(Encoder, self).__init__()
        self.n_encoders = n_encoders
        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.layers = nn.ModuleList()
        self.latent_channels = []
        for i in range(self.n_encoders):
            self.layers.append(
                ResEncoderBlock(
                    in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
                )
            )
            self.latent_channels.append([out_channels, in_size])
            in_channels = out_channels
            out_channels *= 2
            in_size //= 2
        self.out_size = in_size
        self.out_channel = out_channels

    def forward(self, x: torch.Tensor):
        # Returns the bottleneck features plus each stage's pre-pooling
        # activations, which Decoder consumes as skip connections.
        concat_tensors: List[torch.Tensor] = []
        x = self.bn(x)
        for i, layer in enumerate(self.layers):
            t, x = layer(x)
            concat_tensors.append(t)
        return x, concat_tensors


class ResEncoderBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
    ):
        super(ResEncoderBlock, self).__init__()
        self.n_blocks = n_blocks
        self.conv = nn.ModuleList()
        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
        self.kernel_size = kernel_size
        if self.kernel_size is not None:
            self.pool = nn.AvgPool2d(kernel_size=kernel_size)

    def forward(self, x):
        for i, conv in enumerate(self.conv):
            x = conv(x)
        if self.kernel_size is not None:
            return x, self.pool(x)
        else:
            return x


class Intermediate(nn.Module):
    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
        super(Intermediate, self).__init__()
        self.n_inters = n_inters
        self.layers = nn.ModuleList()
        self.layers.append(
            ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
        )
        for i in range(self.n_inters - 1):
            self.layers.append(
                ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
            )

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = layer(x)
        return x


class ResDecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
        super(ResDecoderBlock, self).__init__()
        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
        self.n_blocks = n_blocks
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=stride,
                padding=(1, 1),
                output_padding=out_padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels, momentum=momentum),
            nn.ReLU(),
        )
        self.conv2 = nn.ModuleList()
        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
        for i in range(n_blocks - 1):
            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

    def forward(self, x, concat_tensor):
        x = self.conv1(x)
        x = torch.cat((x, concat_tensor), dim=1)
        for i, conv2 in enumerate(self.conv2):
            x = conv2(x)
        return x


class Decoder(nn.Module):
    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList()
        self.n_decoders = n_decoders
        for i in range(self.n_decoders):
            out_channels = in_channels // 2
            self.layers.append(
                ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
            )
            in_channels = out_channels

    def forward(self, x: torch.Tensor, concat_tensors: List[torch.Tensor]):
        for i, layer in enumerate(self.layers):
            x = layer(x, concat_tensors[-1 - i])
        return x


class DeepUnet(nn.Module):
    def __init__(
        self,
        kernel_size,
        n_blocks,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(DeepUnet, self).__init__()
        self.encoder = Encoder(
            in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
        )
        self.intermediate = Intermediate(
            self.encoder.out_channel // 2,
            self.encoder.out_channel,
            inter_layers,
            n_blocks,
        )
        self.decoder = Decoder(
            self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, concat_tensors = self.encoder(x)
        x = self.intermediate(x)
        x = self.decoder(x, concat_tensors)
        return x

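
# Shape sketch (a hedged check against the constructors above, not a documented
# contract): with kernel_size=(2, 2) and en_de_layers=5, five (2, 2) poolings
# and five stride-(2, 2) transposed convolutions cancel out, so the frame axis
# must be a multiple of 32 (mel2hidden below pads for this):
#   net = DeepUnet(kernel_size=(2, 2), n_blocks=4)
#   x = torch.randn(1, 1, 64, 128)  # (batch, channel, frames, 128 mel bins)
#   net(x).shape                    # torch.Size([1, 16, 64, 128])
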

class E2E(nn.Module):
    def __init__(
        self,
        n_blocks,
        n_gru,
        kernel_size,
        en_de_layers=5,
        inter_layers=4,
        in_channels=1,
        en_out_channels=16,
    ):
        super(E2E, self).__init__()
        self.unet = DeepUnet(
            kernel_size,
            n_blocks,
            en_de_layers,
            inter_layers,
            in_channels,
            en_out_channels,
        )
        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
        if n_gru:
            self.fc = nn.Sequential(
                BiGRU(3 * 128, 256, n_gru),
                nn.Linear(512, 360),
                nn.Dropout(0.25),
                nn.Sigmoid(),
            )
        else:
            # The original line read nn.Linear(3 * nn.N_MELS, nn.N_CLASS), but
            # torch.nn defines neither constant; 128 mel bins and 360 pitch
            # classes are the sizes used throughout this file.
            self.fc = nn.Sequential(
                nn.Linear(3 * 128, 360), nn.Dropout(0.25), nn.Sigmoid()
            )

    def forward(self, mel):
        # print(mel.shape)
        mel = mel.transpose(-1, -2).unsqueeze(1)
        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
        x = self.fc(x)
        # print(x.shape)
        return x


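# Hedged shape sketch: E2E as instantiated by RMVPE below maps a log-mel of
# shape (batch, 128, frames), frames divisible by 32, to per-frame salience
# over 360 twenty-cent pitch bins:
#   model = E2E(4, 1, (2, 2))
#   model(torch.randn(1, 128, 64)).shape  # torch.Size([1, 64, 360])
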
from librosa.filters import mel


class MelSpectrogram(torch.nn.Module):
    def __init__(
        self,
        is_half,
        n_mel_channels,
        sampling_rate,
        win_length,
        hop_length,
        n_fft=None,
        mel_fmin=0,
        mel_fmax=None,
        clamp=1e-5,
    ):
        super().__init__()
        n_fft = win_length if n_fft is None else n_fft
        self.hann_window = {}
        mel_basis = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)
        self.n_fft = n_fft  # already defaulted to win_length above
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp
        self.is_half = is_half

    def forward(self, audio, keyshift=0, speed=1, center=True):
        # A keyshift of k semitones scales the FFT and window sizes by
        # 2 ** (k / 12); the spectrum is cropped and rescaled back afterwards
        # so the precomputed mel basis still applies.
        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(self.n_fft * factor))
        win_length_new = int(np.round(self.win_length * factor))
        hop_length_new = int(np.round(self.hop_length * speed))
        keyshift_key = str(keyshift) + "_" + str(audio.device)
        if keyshift_key not in self.hann_window:
            self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
                audio.device
            )
        if "privateuseone" in str(audio.device):
            # torch.stft is unsupported on DirectML ("privateuseone") devices;
            # fall back to the convolution-based STFT defined above.
            if not hasattr(self, "stft"):
                self.stft = STFT(
                    filter_length=n_fft_new,
                    hop_length=hop_length_new,
                    win_length=win_length_new,
                    window="hann",
                ).to(audio.device)
            magnitude = self.stft.transform(audio)
        else:
            fft = torch.stft(
                audio,
                n_fft=n_fft_new,
                hop_length=hop_length_new,
                win_length=win_length_new,
                window=self.hann_window[keyshift_key],
                center=center,
                return_complex=True,
            )
            magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
        if keyshift != 0:
            size = self.n_fft // 2 + 1
            resize = magnitude.size(1)
            if resize < size:
                magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
        mel_output = torch.matmul(self.mel_basis, magnitude)
        if self.is_half:
            mel_output = mel_output.half()
        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
        return log_mel_spec


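# Hedged usage sketch, configured the way RMVPE below configures its extractor
# (128 mels, 16 kHz, 1024-sample window, 160-sample hop): one second of audio
# yields 101 centered frames:
#   mel_extractor = MelSpectrogram(False, 128, 16000, 1024, 160, None, 30, 8000)
#   mel_extractor(torch.randn(1, 16000)).shape  # torch.Size([1, 128, 101])
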
class RMVPE:
    def __init__(self, model_path: str, is_half, device=None, use_jit=False):
        self.resample_kernel = {}
        self.is_half = is_half
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.mel_extractor = MelSpectrogram(
            is_half, 128, 16000, 1024, 160, None, 30, 8000
        ).to(device)
        if "privateuseone" in str(device):
            import onnxruntime as ort

            ort_session = ort.InferenceSession(
                "%s/rmvpe.onnx" % os.environ["rmvpe_root"],
                providers=["DmlExecutionProvider"],
            )
            self.model = ort_session
        else:
            if str(self.device) == "cuda":
                self.device = torch.device("cuda:0")

            def get_default_model():
                model = E2E(4, 1, (2, 2))
                ckpt = torch.load(model_path, map_location="cpu")
                model.load_state_dict(ckpt)
                model.eval()
                if is_half:
                    model = model.half()
                else:
                    model = model.float()
                return model

            self.model = get_default_model()

            self.model = self.model.to(device)
        # 360 pitch bins spaced 20 cents apart, padded by 4 on each side (368)
        # for the local averaging window in to_local_average_cents().
        cents_mapping = 20 * np.arange(360) + 1997.3794084376191
        self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368

    def mel2hidden(self, mel):
        with torch.no_grad():
            n_frames = mel.shape[-1]
            # Pad the frame axis to a multiple of 32 so the five U-Net
            # down/upsamplings divide evenly; the pad is trimmed on return.
            n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
            if n_pad > 0:
                mel = F.pad(mel, (0, n_pad), mode="constant")
            if "privateuseone" in str(self.device):
                onnx_input_name = self.model.get_inputs()[0].name
                onnx_outputs_names = self.model.get_outputs()[0].name
                hidden = self.model.run(
                    [onnx_outputs_names],
                    input_feed={onnx_input_name: mel.cpu().numpy()},
                )[0]
            else:
                mel = mel.half() if self.is_half else mel.float()
                hidden = self.model(mel)
            return hidden[:, :n_frames]

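    # Worked padding example (a sketch, not original code): for n_frames = 100,
    # n_pad = 32 * ((100 - 1) // 32 + 1) - 100 = 32 * 4 - 100 = 28, so the mel
    # is padded to 128 frames and the first 100 output frames are kept.
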
    def decode(self, hidden, thred=0.03):
        cents_pred = self.to_local_average_cents(hidden, thred=thred)
        # Cents are measured relative to 10 Hz, so frames decoded to 0 cents
        # (peak salience below the threshold) come out at exactly 10 Hz and
        # are zeroed as unvoiced.
        f0 = 10 * (2 ** (cents_pred / 1200))
        f0[f0 == 10] = 0
        # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
        return f0

    def infer_from_audio(self, audio, thred=0.03):
        # torch.cuda.synchronize()
        # t0 = ttime()
        if not torch.is_tensor(audio):
            audio = torch.from_numpy(audio)
        mel = self.mel_extractor(
            audio.float().to(self.device).unsqueeze(0), center=True
        )
        # torch.cuda.synchronize()
        # t1 = ttime()
        hidden = self.mel2hidden(mel)
        # torch.cuda.synchronize()
        # t2 = ttime()
        if "privateuseone" not in str(self.device):
            hidden = hidden.squeeze(0).cpu().numpy()
        else:
            hidden = hidden[0]
        if self.is_half:
            hidden = hidden.astype("float32")

        f0 = self.decode(hidden, thred=thred)
        # torch.cuda.synchronize()
        # t3 = ttime()
        # print("rmvpe:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t3 - t0))
        return f0

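    # Descriptive note (added, not original): to_local_average_cents decodes
    # the (n_frames, 360) salience map by taking each frame's argmax bin and
    # refining it with a salience-weighted average of the 9 bins centered on
    # it (hence the 4-bin padding of cents_mapping), zeroing frames whose
    # peak salience is at or below `thred`.
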
    def to_local_average_cents(self, salience, thred=0.05):
        # t0 = ttime()
        center = np.argmax(salience, axis=1)  # (n_frames,) peak bin index
        salience = np.pad(salience, ((0, 0), (4, 4)))  # (n_frames, 368)
        # t1 = ttime()
        center += 4
        todo_salience = []
        todo_cents_mapping = []
        starts = center - 4
        ends = center + 5
        for idx in range(salience.shape[0]):
            todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
            todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
        # t2 = ttime()
        todo_salience = np.array(todo_salience)  # (n_frames, 9)
        todo_cents_mapping = np.array(todo_cents_mapping)  # (n_frames, 9)
        product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
        weight_sum = np.sum(todo_salience, 1)  # (n_frames,)
        divided = product_sum / weight_sum  # (n_frames,) weighted-average cents
        # t3 = ttime()
        maxx = np.max(salience, axis=1)  # (n_frames,)
        divided[maxx <= thred] = 0
        # t4 = ttime()
        # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
        return divided
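

# Hedged end-to-end sketch (not part of the original file): load a checkpoint
# and extract an F0 curve from 16 kHz mono audio. "rmvpe.pt" is a placeholder
# path for a state dict compatible with E2E(4, 1, (2, 2)).
def _rmvpe_usage_example():
    rmvpe = RMVPE("rmvpe.pt", is_half=False, device="cpu")
    audio = np.random.randn(16000).astype(np.float32)  # 1 s at 16 kHz
    # With hop_length=160 at 16 kHz, f0 holds one value per 10 ms frame;
    # unvoiced frames are 0, voiced ones span roughly 31.7 Hz (bin 0:
    # 10 * 2 ** (1997.38 / 1200)) up to about 2 kHz (bin 359).
    f0 = rmvpe.infer_from_audio(audio, thred=0.03)
    return f0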