# Source: YuE-music-generator-demo / vocos / feature_extractors.py
# Uploaded by KingNish with huggingface_hub (commit 0fc9117, verified; 4.92 kB).
from typing import List, Optional

import torch
import torchaudio
# from encodec import EncodecModel
from omegaconf import OmegaConf
from torch import nn

from vocos.modules import safe_log
# from vocos.xcodec.models.soundstream_hubert_new import SoundStream
class FeatureExtractor(nn.Module):
    """Abstract interface for modules that turn raw audio into feature tensors.

    Concrete subclasses override :meth:`forward`; this base implementation only
    raises.
    """

    def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
        """Compute features for ``audio``.

        Args:
            audio (Tensor): Input audio waveform.
            **kwargs: Extra, subclass-specific options.

        Returns:
            Tensor: Features of shape (B, C, L): batch size, feature channels,
            and sequence length.

        Raises:
            NotImplementedError: Always; subclasses must supply the implementation.
        """
        message = "Subclasses must implement the forward method."
        raise NotImplementedError(message)
class MelSpectrogramFeatures(FeatureExtractor):
    """Log-magnitude mel-spectrogram front end.

    Args:
        sample_rate (int): Audio sample rate in Hz.
        n_fft (int): FFT size.
        hop_length (int): Hop between successive frames, in samples.
        n_mels (int): Number of mel bins.
        padding (str): ``"center"`` lets torchaudio center the frames;
            ``"same"`` disables that and instead reflect-pads the waveform in
            :meth:`forward` by ``(win_length - hop_length) // 2`` on each side.

    Raises:
        ValueError: If ``padding`` is neither ``"center"`` nor ``"same"``.
    """

    def __init__(self, sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
        super().__init__()
        if padding not in ("center", "same"):
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        use_center = padding == "center"
        # power=1 -> magnitude (not power) spectrogram before the log.
        self.mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
            center=use_center,
            power=1,
        )

    def forward(self, audio, **kwargs):
        """Return ``safe_log`` of the mel spectrogram of ``audio``."""
        if self.padding == "same":
            half = (self.mel_spec.win_length - self.mel_spec.hop_length) // 2
            audio = torch.nn.functional.pad(audio, (half, half), mode="reflect")
        return safe_log(self.mel_spec(audio))
class EncodecFeatures(FeatureExtractor):
    """Continuous features built from a frozen pretrained EnCodec model's codes.

    Audio is encoded to residual-VQ code indices, which are looked up in the
    concatenated codebooks (``self.codebook_weights``) and summed over
    quantizers.

    Args:
        encodec_model (str): ``"encodec_24khz"`` or ``"encodec_48khz"``.
        bandwidths (List[float], optional): Target bandwidths selectable through
            ``bandwidth_id`` in :meth:`forward`. Defaults to
            ``[1.5, 3.0, 6.0, 12.0]``. The list is copied, so later changes to
            the caller's list do not affect this module.
        train_codebooks (bool): Whether the copied codebook weights are trainable.

    Raises:
        ValueError: If ``encodec_model`` is unsupported or ``bandwidths`` is empty.
        ImportError: If the ``encodec`` package is not installed.
    """

    DEFAULT_BANDWIDTHS = (1.5, 3.0, 6.0, 12.0)

    def __init__(
        self,
        encodec_model: str = "encodec_24khz",
        bandwidths: Optional[List[float]] = None,
        train_codebooks: bool = False,
    ):
        super().__init__()
        # Validate the name before importing, so an unsupported name is reported
        # as ValueError regardless of whether encodec is installed.
        if encodec_model not in ("encodec_24khz", "encodec_48khz"):
            raise ValueError(
                f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'."
            )
        # Copy so the stored list never aliases a shared default or the caller's list.
        bandwidths = list(self.DEFAULT_BANDWIDTHS if bandwidths is None else bandwidths)
        if not bandwidths:
            raise ValueError("bandwidths must contain at least one value")

        # Imported lazily: the module-level import is commented out so this file
        # loads without `encodec`; previously construction failed with NameError.
        try:
            from encodec import EncodecModel
        except ImportError as err:
            raise ImportError("EncodecFeatures requires the 'encodec' package to be installed.") from err

        if encodec_model == "encodec_24khz":
            encodec = EncodecModel.encodec_model_24khz
        else:
            encodec = EncodecModel.encodec_model_48khz

        self.encodec = encodec(pretrained=True)
        for param in self.encodec.parameters():
            param.requires_grad = False
        # Number of quantizers needed at the largest bandwidth; only those codebooks are kept.
        self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth(
            self.encodec.frame_rate, bandwidth=max(bandwidths)
        )
        codebook_weights = torch.cat([vq.codebook for vq in self.encodec.quantizer.vq.layers[: self.num_q]], dim=0)
        self.codebook_weights = torch.nn.Parameter(codebook_weights, requires_grad=train_codebooks)
        self.bandwidths = bandwidths

    @torch.no_grad()
    def get_encodec_codes(self, audio):
        """Encode ``audio`` to code indices at the encoder's current target bandwidth.

        A channel axis is inserted at dim 1 before encoding (presumably
        (B, T) -> (B, 1, T); TODO confirm against callers). ``forward`` treats
        dim 0 of the result as the quantizer axis.
        """
        audio = audio.unsqueeze(1)
        emb = self.encodec.encoder(audio)
        codes = self.encodec.quantizer.encode(emb, self.encodec.frame_rate, self.encodec.bandwidth)
        return codes

    def forward(self, audio: torch.Tensor, **kwargs):
        """Return summed codebook embeddings for ``audio``, transposed with ``transpose(1, 2)``.

        Args:
            audio (Tensor): Input waveform batch.
            bandwidth_id (keyword, required): Index into ``self.bandwidths``
                selecting the target bandwidth.

        Raises:
            ValueError: If ``bandwidth_id`` is not supplied.
        """
        bandwidth_id = kwargs.get("bandwidth_id")
        if bandwidth_id is None:
            raise ValueError("The 'bandwidth_id' argument is required")
        self.encodec.eval()  # Force eval mode as Pytorch Lightning automatically sets child modules to training mode
        self.encodec.set_target_bandwidth(self.bandwidths[bandwidth_id])
        codes = self.get_encodec_codes(audio)
        # Instead of summing in a loop, the per-quantizer codebooks are stored back
        # to back in `self.codebook_weights`; each quantizer's codes are shifted by
        # its offset (a multiple of `bins`) and the lookups are summed at once.
        bins = self.encodec.quantizer.bins
        offsets = torch.arange(0, bins * len(codes), bins, device=audio.device)
        embeddings_idxs = codes + offsets.view(-1, 1, 1)
        features = torch.nn.functional.embedding(embeddings_idxs, self.codebook_weights).sum(dim=0)
        return features.transpose(1, 2)
class xCodecFeatures(FeatureExtractor):
    """Encode audio with a frozen pretrained xcodec-style codec.

    Audio is resampled (44.1 kHz -> 16 kHz by default) and passed to the codec's
    ``encode`` method. The return value is whatever ``self.model.encode``
    produces (named ``codes`` here). NOTE(review): its shape/dtype is defined
    by the codec implementation, which is not visible here; confirm before
    relying on the (B, C, L) feature convention of the base class.

    Args:
        config (str): Path to an OmegaConf file with ``generator.name`` and
            ``generator.config`` entries.
        ckpt (str): Path to a checkpoint whose ``'codec_model'`` entry is the
            codec state dict.
        orig_freq (int): Sample rate of the audio given to :meth:`forward`.
        new_freq (int): Sample rate the audio is resampled to before encoding.
        target_bw (float): Value passed as ``target_bw`` to ``model.encode``.
    """

    def __init__(
        self,
        config: str,
        ckpt: str,
        orig_freq: int = 44100,
        new_freq: int = 16000,
        target_bw: float = 6,
    ):
        super().__init__()
        self.config = OmegaConf.load(config)
        # NOTE(review): eval() runs arbitrary code taken from the config file; only
        # load trusted configs. The named class (e.g. SoundStream) must also be in
        # this module's namespace; its import is commented out at the top of the file.
        self.model = eval(self.config.generator.name)(**self.config.generator.config)
        # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
        parameter_dict = torch.load(ckpt, map_location="cpu")
        self.model.load_state_dict(parameter_dict["codec_model"])
        # Frozen, matching EncodecFeatures; forward runs under no_grad anyway.
        for param in self.model.parameters():
            param.requires_grad = False
        # Built on CPU like the model (previously pinned to 'cuda', which crashed on
        # CPU-only hosts); forward() moves it to the input's device.
        self.resampler = torchaudio.transforms.Resample(orig_freq=orig_freq, new_freq=new_freq)
        self.target_bw = target_bw
        self.model.eval()

    def forward(self, audio: torch.Tensor, **kwargs):
        """Resample ``audio`` and return ``self.model.encode(...)`` on it.

        Extra keyword arguments are accepted for interface compatibility with the
        other feature extractors and are ignored.
        """
        # Force eval mode, as PyTorch Lightning may switch child modules back to training mode.
        self.model.eval()
        # Keep the resampler's kernel on the same device as the input.
        self.resampler.to(audio.device)
        audio = self.resampler(audio)
        with torch.no_grad():
            codes = self.model.encode(audio, target_bw=self.target_bw)
        return codes