"""
This module contains the AudioProcessor class and related functions for processing audio data.

It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
and audio separation. The class is initialized with configuration parameters and can process
audio files using the provided models.
"""
|
import math
import os

import librosa
import numpy as np
import torch
from audio_separator.separator import Separator
from einops import rearrange
from transformers import Wav2Vec2FeatureExtractor

from hallo.models.wav2vec import Wav2VecModel
from hallo.utils.util import resample_audio


class AudioProcessor:
""" |
|
AudioProcessor is a class that handles the processing of audio files. |
|
It takes care of preprocessing the audio files, extracting features |
|
using wav2vec models, and separating audio signals if needed. |
|
|
|
:param sample_rate: Sampling rate of the audio file |
|
:param fps: Frames per second for the extracted features |
|
:param wav2vec_model_path: Path to the wav2vec model |
|
:param only_last_features: Whether to only use the last features |
|
:param audio_separator_model_path: Path to the audio separator model |
|
:param audio_separator_model_name: Name of the audio separator model |
|
:param cache_dir: Directory to cache the intermediate results |
|
:param device: Device to run the processing on |
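
    Example (illustrative only; the checkpoint path and audio file are
    placeholders, not assets shipped with this module)::

        processor = AudioProcessor(
            sample_rate=16000,
            fps=25,
            wav2vec_model_path="./pretrained_models/wav2vec2-base-960h",
            only_last_features=False,
        )
        audio_emb = processor.preprocess("driving_audio.wav")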
|
""" |
|
    def __init__(
        self,
        sample_rate,
        fps,
        wav2vec_model_path,
        only_last_features,
        audio_separator_model_path: str = None,
        audio_separator_model_name: str = None,
        cache_dir: str = '',
        device: str = "cuda:0",
    ) -> None:
        self.sample_rate = sample_rate
        self.fps = fps
        self.device = device

        # Load the wav2vec encoder and freeze its convolutional feature extractor.
        self.audio_encoder = Wav2VecModel.from_pretrained(wav2vec_model_path, local_files_only=True).to(device=device)
        self.audio_encoder.feature_extractor._freeze_parameters()
        self.only_last_features = only_last_features

        if audio_separator_model_name is not None:
            try:
                os.makedirs(cache_dir, exist_ok=True)
            except OSError:
                print("Failed to create the output cache dir.")
            self.audio_separator = Separator(
                output_dir=cache_dir,
                output_single_stem="vocals",
                model_file_dir=audio_separator_model_path,
            )
            self.audio_separator.load_model(audio_separator_model_name)
            assert self.audio_separator.model_instance is not None, "Failed to load the audio separation model."
        else:
            self.audio_separator = None
            print("Using the audio directly, without vocal separation.")

        self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_model_path, local_files_only=True)

    def preprocess(self, wav_file: str):
""" |
|
Preprocess a WAV audio file by separating the vocals from the background and resampling it to a 16 kHz sample rate. |
|
The separated vocal track is then converted into wav2vec2 for further processing or analysis. |
|
|
|
Args: |
|
wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. |
|
|
|
Raises: |
|
RuntimeError: Raises an exception if the WAV file cannot be processed. This could be due to issues |
|
such as file not found, unsupported file format, or errors during the audio processing steps. |
|
|
|
Returns: |
|
torch.tensor: Returns an audio embedding as a torch.tensor |
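
        Example (illustrative; assumes a processor constructed as in the
        class docstring)::

            audio_emb = processor.preprocess("driving_audio.wav")
            # shape: (seq_len, hidden_dim) if only_last_features is True,
            # otherwise (seq_len, num_layers, hidden_dim)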
|
""" |
|
        if self.audio_separator is not None:
            # Isolate the vocal stem; the separator writes its output into cache_dir.
            outputs = self.audio_separator.separate(wav_file)
            if len(outputs) <= 0:
                raise RuntimeError("Audio separation failed.")

            vocal_audio_file = outputs[0]
            vocal_audio_name, _ = os.path.splitext(vocal_audio_file)
            vocal_audio_file = os.path.join(self.audio_separator.output_dir, vocal_audio_file)
            # Resample the separated vocal track to the target sample rate.
            vocal_audio_file = resample_audio(vocal_audio_file, os.path.join(self.audio_separator.output_dir, f"{vocal_audio_name}-16k.wav"), self.sample_rate)
        else:
            vocal_audio_file = wav_file

        # Convert the waveform into wav2vec2 input values. seq_len is the number
        # of video frames the audio spans: duration in seconds times fps.
        speech_array, sampling_rate = librosa.load(vocal_audio_file, sr=self.sample_rate)
        audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
        seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)

        audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
        audio_feature = audio_feature.unsqueeze(0)

        with torch.no_grad():
            embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True)
        assert len(embeddings) > 0, "Failed to extract audio embedding."
        if self.only_last_features:
            audio_emb = embeddings.last_hidden_state.squeeze()
        else:
            # Stack the hidden states of all transformer layers (skipping the input
            # embeddings) and move the sequence axis first: (seq_len, layers, dim).
            audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
            audio_emb = rearrange(audio_emb, "b s d -> s b d")

        audio_emb = audio_emb.cpu().detach()

        return audio_emb

    def get_embedding(self, wav_file: str):
"""preprocess wav audio file convert to embeddings |
|
|
|
Args: |
|
wav_file (str): The path to the WAV file to be processed. This file should be accessible and in WAV format. |
|
|
|
Returns: |
|
torch.tensor: Returns an audio embedding as a torch.tensor |
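
        Example (illustrative; the filename is a placeholder for a 16 kHz
        vocal track, e.g. one written by a previous preprocess() call)::

            audio_emb = processor.get_embedding("vocals-16k.wav")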
|
""" |
|
        speech_array, sampling_rate = librosa.load(wav_file, sr=self.sample_rate)
        assert sampling_rate == 16000, "The audio sample rate must be 16000."
        audio_feature = np.squeeze(self.wav2vec_feature_extractor(speech_array, sampling_rate=sampling_rate).input_values)
        seq_len = math.ceil(len(audio_feature) / self.sample_rate * self.fps)

        audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
        audio_feature = audio_feature.unsqueeze(0)

        with torch.no_grad():
            embeddings = self.audio_encoder(audio_feature, seq_len=seq_len, output_hidden_states=True)
        assert len(embeddings) > 0, "Failed to extract audio embedding."

        if self.only_last_features:
            audio_emb = embeddings.last_hidden_state.squeeze()
        else:
            audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
            audio_emb = rearrange(audio_emb, "b s d -> s b d")

        audio_emb = audio_emb.cpu().detach()

        return audio_emb

    def close(self):
        """
        TODO: to be implemented
        """
        return self

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        self.close()
|
|
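# Minimal usage sketch (illustrative only): AudioProcessor implements the
# context-manager protocol, so `with` guarantees close() runs even if
# embedding extraction raises. The checkpoint path and audio file below are
# placeholders, not assets shipped with this module.
if __name__ == "__main__":
    with AudioProcessor(
        sample_rate=16000,
        fps=25,
        wav2vec_model_path="./pretrained_models/wav2vec2-base-960h",
        only_last_features=False,
    ) as processor:
        audio_emb = processor.preprocess("driving_audio.wav")
        print(audio_emb.shape)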