#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import logging
import os
from pathlib import Path
import sys
import uuid

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import librosa
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
from tqdm import tqdm

from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
    parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
    parser.add_argument("--limit", default=10, type=int)
    args = parser.parse_args()
    return args


def logging_config():
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(logging.Formatter(fmt))

    logger = logging.getLogger(__name__)
    return logger


def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
    """Mix clean speech with noise at the requested SNR (in dB).

    Both inputs are np.float32 waveforms with values in (-1, 1);
    they are truncated to the shorter of the two lengths.
    """
    l = min(len(speech), len(noise))
    speech = speech[:l]
    noise = noise[:l]

    # Scale the noise so that speech_power / noise_power == 10 ** (snr_db / 10).
    speech_power = np.mean(np.square(speech))
    noise_power = speech_power / (10 ** (snr_db / 10))
    noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))

    noisy_signal = speech + noise_adjusted
    return noisy_signal


# STFT settings: 512-point FFT over 200-sample (25 ms) Hamming windows
# with an 80-sample (10 ms) hop, for 8 kHz audio.
stft_power = torchaudio.transforms.Spectrogram(
    n_fft=512,
    win_length=200,
    hop_length=80,
    power=2.0,  # power spectrogram (magnitude squared)
    window_fn=torch.hamming_window,
)

stft_complex = torchaudio.transforms.Spectrogram(
    n_fft=512,
    win_length=200,
    hop_length=80,
    power=None,  # complex spectrogram, needed for inversion
    window_fn=torch.hamming_window,
)

istft = torchaudio.transforms.InverseSpectrogram(
    n_fft=512,
    win_length=200,
    hop_length=80,
    window_fn=torch.hamming_window,
)


def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
    """Apply the predicted IRM to the complex mixture spectrogram and
    invert both the speech and the residual-noise estimates to waveforms."""
    mix_spec_complex = mix_spec_complex.detach().cpu()
    speech_irm_prediction = speech_irm_prediction.detach().cpu()

    mask_speech = speech_irm_prediction
    mask_noise = 1.0 - speech_irm_prediction

    speech_spec = mix_spec_complex * mask_speech
    noise_spec = mix_spec_complex * mask_noise

    speech_wave = istft.forward(speech_spec)
    noise_wave = istft.forward(noise_spec)

    return speech_wave, noise_wave


def save_audios(noise_wave: torch.Tensor, speech_wave: torch.Tensor, mix_wave: torch.Tensor,
                speech_wave_enhanced: torch.Tensor, noise_wave_enhanced: torch.Tensor,
                output_dir: str,
                sample_rate: int = 8000,
                ):
    # Each example gets its own subdirectory, named with a random UUID.
    basename = uuid.uuid4().__str__()
    output_dir = Path(output_dir) / basename
    output_dir.mkdir(parents=True, exist_ok=True)

    filename = output_dir / "noise_wave.wav"
    torchaudio.save(filename.as_posix(), noise_wave, sample_rate)

    filename = output_dir / "speech_wave.wav"
    torchaudio.save(filename.as_posix(), speech_wave, sample_rate)

    filename = output_dir / "mix_wave.wav"
    torchaudio.save(filename.as_posix(), mix_wave, sample_rate)

    filename = output_dir / "speech_wave_enhanced.wav"
    torchaudio.save(filename.as_posix(), speech_wave_enhanced, sample_rate)

    filename = output_dir / "noise_wave_enhanced.wav"
    torchaudio.save(filename.as_posix(), noise_wave_enhanced, sample_rate)

    return output_dir.as_posix()


def main():
    args = get_args()

    logger = logging_config()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    logger.info("prepare model")
    model = SimpleLinearIRMPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
    )
    model.to(device)
    model.eval()

    logger.info("prepare loss_fn")
    mse_loss = nn.MSELoss(
        reduction="mean",
    )

    logger.info("read excel")
    df = pd.read_excel(args.valid_dataset)

    total_loss = 0.
    total_examples = 0.
    progress_bar = tqdm(total=len(df), desc="Evaluation")
    for idx, row in df.iterrows():
        noise_filename = row["noise_filename"]
        noise_offset = row["noise_offset"]
        noise_duration = row["noise_duration"]

        speech_filename = row["speech_filename"]
        speech_offset = row["speech_offset"]
        speech_duration = row["speech_duration"]

        snr_db = row["snr_db"]

        # Load the noise and speech segments at 8 kHz, then mix them at
        # the SNR specified by the row.
        noise_wave, _ = librosa.load(
            noise_filename,
            sr=8000,
            offset=noise_offset,
            duration=noise_duration,
        )
        speech_wave, _ = librosa.load(
            speech_filename,
            sr=8000,
            offset=speech_offset,
            duration=speech_duration,
        )
        mix_wave: np.ndarray = mix_speech_and_noise(
            speech=speech_wave,
            noise=noise_wave,
            snr_db=snr_db,
        )
        noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
        speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
        mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)

        # Add a batch dimension: (time,) -> (1, time).
        noise_wave = noise_wave.unsqueeze(dim=0)
        speech_wave = speech_wave.unsqueeze(dim=0)
        mix_wave = mix_wave.unsqueeze(dim=0)

        noise_spec: torch.Tensor = stft_power.forward(noise_wave)
        speech_spec: torch.Tensor = stft_power.forward(speech_wave)
        mix_spec: torch.Tensor = stft_power.forward(mix_wave)
        mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)

        # Ideal ratio mask target. Note that frames where speech and noise
        # power are both zero yield NaN; the pow with exponent 1.0 leaves
        # the mask unchanged.
        speech_irm = speech_spec / (noise_spec + speech_spec)
        speech_irm = torch.pow(speech_irm, 1.0)

        mix_spec = mix_spec.to(device)
        speech_irm_target = speech_irm.to(device)

        with torch.no_grad():
            speech_irm_prediction = model.forward(mix_spec)
            loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)

        speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
        save_audios(noise_wave, speech_wave, mix_wave,
                    speech_wave_enhanced, noise_wave_enhanced,
                    args.evaluation_audio_dir)

        total_loss += loss.item()
        total_examples += mix_spec.size(0)

        evaluation_loss = round(total_loss / total_examples, 4)

        progress_bar.update(1)
        progress_bar.set_postfix({
            "evaluation_loss": evaluation_loss,
        })

        if idx > args.limit:
            break

    return


if __name__ == '__main__':
    main()
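
# Example invocation (the argument values below are just the argparse
# defaults, and the script filename is illustrative; adjust paths to
# your own setup):
#
#   python3 evaluation.py \
#       --valid_dataset valid.xlsx \
#       --model_dir serialization_dir/best \
#       --evaluation_audio_dir evaluation_audio_dir \
#       --limit 10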