#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/yxlu-0102/MP-SENet/blob/main/inference.py
"""
import argparse
import logging
import os
from pathlib import Path
import sys
import uuid

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import numpy as np
import pandas as pd
import torch
import torchaudio
from tqdm import tqdm

from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel
from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
    parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
    parser.add_argument("--limit", default=10, type=int)
    args = parser.parse_args()
    return args

def logging_config():
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
    # logging.basicConfig already installs a stream handler on the root logger,
    # so no extra handler needs to be attached here.
    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    logger = logging.getLogger(__name__)
    return logger

def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
    # Both inputs are np.float32 with values in (-1, 1); truncate to the common length.
    l1 = len(speech)
    l2 = len(noise)
    l = min(l1, l2)
    speech = speech[:l]
    noise = noise[:l]

    # Scale the noise so that the speech-to-noise power ratio equals snr_db.
    speech_power = np.mean(np.square(speech))
    noise_power = speech_power / (10 ** (snr_db / 10))
    noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))

    # The sum may exceed (-1, 1) at low SNRs; no clipping is applied here.
    noisy_signal = speech + noise_adjusted
    return noisy_signal
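
# For illustration: at snr_db = 0 the scaled noise carries the same average power as
# the speech, and at snr_db = 10 it carries one tenth of it, i.e.
# np.mean(noise_adjusted ** 2) equals speech_power / 10 ** (snr_db / 10) up to rounding.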

def save_audios(noise_audio: torch.Tensor,
                clean_audio: torch.Tensor,
                noisy_audio: torch.Tensor,
                enhanced_audio: torch.Tensor,
                output_dir: str,
                sample_rate: int = 8000,
                ):
    # Each evaluation example gets its own directory, named by a random UUID.
    basename = str(uuid.uuid4())
    output_dir = Path(output_dir) / basename
    output_dir.mkdir(parents=True, exist_ok=True)

    filename = output_dir / "noise_audio.wav"
    torchaudio.save(filename, noise_audio.detach().cpu(), sample_rate, bits_per_sample=16)
    filename = output_dir / "clean_audio.wav"
    torchaudio.save(filename, clean_audio.detach().cpu(), sample_rate, bits_per_sample=16)
    filename = output_dir / "noisy_audio.wav"
    torchaudio.save(filename, noisy_audio.detach().cpu(), sample_rate, bits_per_sample=16)
    filename = output_dir / "enhanced_audio.wav"
    torchaudio.save(filename, enhanced_audio.detach().cpu(), sample_rate, bits_per_sample=16)

    return output_dir.as_posix()
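
# Note: torchaudio.save expects 2D (channels, frames) tensors; main() below adds the
# channel dimension with unsqueeze(dim=0) before the tensors reach save_audios.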

def main():
    args = get_args()

    logger = logging_config()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    logger.info("prepare model")
    config = MPNetConfig.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
    )
    generator = MPNetPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
    )
    generator.to(device)
    generator.eval()

    logger.info("read excel")
    df = pd.read_excel(args.valid_dataset)

    progress_bar = tqdm(total=len(df), desc="Evaluation")
    for idx, row in df.iterrows():
        # Each row of the Excel sheet specifies one (noise, speech) pair and a target SNR.
        noise_filename = row["noise_filename"]
        noise_offset = row["noise_offset"]
        noise_duration = row["noise_duration"]
        speech_filename = row["speech_filename"]
        speech_offset = row["speech_offset"]
        speech_duration = row["speech_duration"]
        snr_db = row["snr_db"]

        noise_audio, _ = librosa.load(
            noise_filename,
            sr=8000,
            offset=noise_offset,
            duration=noise_duration,
        )
        clean_audio, _ = librosa.load(
            speech_filename,
            sr=8000,
            offset=speech_offset,
            duration=speech_duration,
        )
        noisy_audio: np.ndarray = mix_speech_and_noise(
            speech=clean_audio,
            noise=noise_audio,
            snr_db=snr_db,
        )

        noise_audio = torch.tensor(noise_audio, dtype=torch.float32)
        clean_audio = torch.tensor(clean_audio, dtype=torch.float32)
        noisy_audio: torch.Tensor = torch.tensor(noisy_audio, dtype=torch.float32)

        # Add the channel dimension: (time,) -> (1, time).
        noise_audio = noise_audio.unsqueeze(dim=0)
        clean_audio = clean_audio.unsqueeze(dim=0)
        noisy_audio = noisy_audio.unsqueeze(dim=0)
        # inference
        clean_audio = clean_audio.to(device)
        noisy_audio = noisy_audio.to(device)
        with torch.no_grad():
            # Compressed magnitude and phase spectra of the noisy input.
            noisy_mag, noisy_pha, noisy_com = mag_pha_stft(
                noisy_audio, config.n_fft, config.hop_size, config.win_size, config.compress_factor
            )
            # The generator predicts the enhanced magnitude and phase.
            mag_g, pha_g, com_g = generator(noisy_mag, noisy_pha)
            # Back to the time domain.
            audio_g = mag_pha_istft(
                mag_g, pha_g, config.n_fft, config.hop_size, config.win_size, config.compress_factor
            )
            enhanced_audio = audio_g.detach()
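
        # Write noise / clean / noisy / enhanced side by side so each example
        # can be compared by ear.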
        save_audios(
            noise_audio, clean_audio, noisy_audio,
            enhanced_audio,
            args.evaluation_audio_dir
        )

        progress_bar.update(1)
        # Stop once `limit` examples have been generated.
        if idx + 1 >= args.limit:
            break

    progress_bar.close()
    return


if __name__ == '__main__':
    main()