KingNish committed · Commit 0fc9117 · verified · 1 Parent(s): bd3b355

Upload ./vocos/feature_extractors.py with huggingface_hub

Files changed (1)
  1. vocos/feature_extractors.py +120 -0
vocos/feature_extractors.py ADDED
@@ -0,0 +1,120 @@
+ from typing import List
+
+ import torch
+ import torchaudio
+ # from encodec import EncodecModel
+ from torch import nn
+ from omegaconf import OmegaConf
+ from vocos.modules import safe_log
+ # from vocos.xcodec.models.soundstream_hubert_new import SoundStream
+
+ class FeatureExtractor(nn.Module):
+     """Base class for feature extractors."""
+
+     def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
+         """
+         Extract features from the given audio.
+
+         Args:
+             audio (Tensor): Input audio waveform.
+
+         Returns:
+             Tensor: Extracted features of shape (B, C, L), where B is the batch size,
+                 C denotes output features, and L is the sequence length.
+         """
+         raise NotImplementedError("Subclasses must implement the forward method.")
+
+
+ class MelSpectrogramFeatures(FeatureExtractor):
+     def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
+         super().__init__()
+         if padding not in ["center", "same"]:
+             raise ValueError("Padding must be 'center' or 'same'.")
+         self.padding = padding
+         self.mel_spec = torchaudio.transforms.MelSpectrogram(
+             sample_rate=sample_rate,
+             n_fft=n_fft,
+             hop_length=hop_length,
+             n_mels=n_mels,
+             center=padding == "center",
+             power=1,
+         )
+
+     def forward(self, audio, **kwargs):
+         if self.padding == "same":
+             pad = self.mel_spec.win_length - self.mel_spec.hop_length
+             audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
+         mel = self.mel_spec(audio)
+         features = safe_log(mel)
+         return features
+
+
+ class EncodecFeatures(FeatureExtractor):
+     def __init__(
+         self,
+         encodec_model: str = "encodec_24khz",
+         bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
+         train_codebooks: bool = False,
+     ):
+         super().__init__()
+         if encodec_model == "encodec_24khz":
+             encodec = EncodecModel.encodec_model_24khz
+         elif encodec_model == "encodec_48khz":
+             encodec = EncodecModel.encodec_model_48khz
+         else:
+             raise ValueError(
+                 f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'."
+             )
+         self.encodec = encodec(pretrained=True)
+         for param in self.encodec.parameters():
+             param.requires_grad = False
+         self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth(
+             self.encodec.frame_rate, bandwidth=max(bandwidths)
+         )
+         codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
+         self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
+         self.bandwidths = bandwidths
+
+     @torch.no_grad()
+     def get_encodec_codes(self, audio):
+         audio = audio.unsqueeze(1)
+         emb = self.encodec.encoder(audio)
+         codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
+         return codes
+
+     def forward(self, audio: torch.Tensor, **kwargs):
+         bandwidth_id = kwargs.get("bandwidth_id")
+         if bandwidth_id is None:
+             raise ValueError("The 'bandwidth_id' argument is required")
+         self.encodec.eval()  # Force eval mode as PyTorch Lightning automatically sets child modules to training mode
+         self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id])
+         codes = self.get_encodec_codes(audio)
+         # Instead of summing in the loop, it stores subsequent VQ dictionaries in a single `self.codebook_weights`
+         # with offsets given by the number of bins, and finally summed in a vectorized operation.
+         offsets = torch.arange(
+             0, self.encodec.quantizer.bins * len(codes), self.encodec.quantizer.bins, device=audio.device
+         )
+         embeddings_idxs = codes + offsets.view(-1, 1, 1)
+         features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
+         return features.transpose(1, 2)
+
+ class xCodecFeatures(FeatureExtractor):
+     def __init__(
+         self,
+         config: str,
+         ckpt: str,
+     ):
+         super().__init__()
+         self.config = OmegaConf.load(config)
+         self.model = eval(self.config.generator.name)(**self.config.generator.config)
+         parameter_dict = torch.load(ckpt, map_location='cpu')
+         self.model.load_state_dict(parameter_dict['codec_model'])
+         self.resampler = torchaudio.transforms.Resample(orig_freq=44100, new_freq=16000).to('cuda')
+         self.model.eval()
+
+     def forward(self, audio: torch.Tensor):
+         # resample audio from 44100 to 16000
+         audio = self.resampler(audio)
+         with torch.no_grad():
+             codes = self.model.encode(audio, target_bw=6)
+         return codes
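
For reference, here is a minimal usage sketch for the MelSpectrogramFeatures class added in this file. It relies only on the constructor defaults and forward signature shown in the diff; the random input waveform and the printed shape are illustrative assumptions, not part of the commit.

    import torch
    from vocos.feature_extractors import MelSpectrogramFeatures

    # Defaults from the class above: 24 kHz audio, 1024-point FFT, hop length 256, 100 mel bins.
    extractor = MelSpectrogramFeatures()

    audio = torch.randn(2, 24000)   # hypothetical batch of two 1-second waveforms at 24 kHz
    features = extractor(audio)     # safe_log of the mel spectrogram, shape (B, n_mels, frames)
    print(features.shape)           # roughly torch.Size([2, 100, 94]) with center padding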
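
The vectorized codebook lookup in EncodecFeatures.forward (offsets added to the codes, one embedding call, then a sum over quantizers) can be easier to follow on toy tensors. The sketch below uses made-up sizes (2 quantizers, 4 bins each, embedding dimension 3) rather than the real EnCodec codebooks, and checks that the trick matches a plain per-quantizer loop.

    import torch

    num_q, bins, dim = 2, 4, 3   # toy sizes; EnCodec uses 1024 bins per codebook
    codebooks = torch.arange(num_q * bins * dim, dtype=torch.float).view(num_q * bins, dim)

    # codes has shape (num_q, B, L): one index per quantizer, batch item, and frame
    codes = torch.tensor([[[0, 3]], [[1, 2]]])   # num_q=2, B=1, L=2

    # Shift each quantizer's indices into its own slice of the stacked codebook table
    offsets = torch.arange(0, bins * num_q, bins)   # tensor([0, 4])
    idxs = codes + offsets.view(-1, 1, 1)           # quantizer 1 now addresses rows 4..7

    # One embedding lookup over the concatenated codebooks, summed over quantizers
    vectorized = torch.nn.functional.embedding(idxs, codebooks).sum(dim=0)

    # Equivalent loop: look up each quantizer in its own codebook, then sum
    looped = sum(
        torch.nn.functional.embedding(codes[q], codebooks[q * bins:(q + 1) * bins])
        for q in range(num_q)
    )
    assert torch.equal(vectorized, looped)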