#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import logging
import os
from pathlib import Path
import sys
import uuid
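# make the project root importable when this script is run directly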
pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))
import librosa
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchaudio
from tqdm import tqdm
from toolbox.torchaudio.models.simple_linear_irm.modeling_simple_linear_irm import SimpleLinearIRMPretrainedModel
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
parser.add_argument("--model_dir", default="serialization_dir/best", type=str)
parser.add_argument("--evaluation_audio_dir", default="evaluation_audio_dir", type=str)
parser.add_argument("--limit", default=10, type=int)
args = parser.parse_args()
return args
def logging_config():
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
    # basicConfig attaches a stream handler with this format to the root logger.
    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    logger = logging.getLogger(__name__)
    return logger
def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
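    """
    Mix a speech and a noise clip at the given SNR (in dB).
    Both inputs are truncated to the shorter of the two lengths.
    """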
l1 = len(speech)
l2 = len(noise)
l = min(l1, l2)
speech = speech[:l]
noise = noise[:l]
# np.float32, value between (-1, 1).
speech_power = np.mean(np.square(speech))
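    # noise power implied by the target SNR: snr_db = 10 * log10(speech_power / noise_power)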
noise_power = speech_power / (10 ** (snr_db / 10))
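    # normalize the noise to unit power, then scale it to the target power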
noise_adjusted = np.sqrt(noise_power) * noise / np.sqrt(np.mean(noise ** 2))
noisy_signal = speech + noise_adjusted
return noisy_signal
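
# 512-point FFT with a 200-sample (25 ms) Hamming window and an 80-sample
# (10 ms) hop at the 8 kHz sample rate used throughout this script.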
stft_power = torchaudio.transforms.Spectrogram(
n_fft=512,
win_length=200,
hop_length=80,
power=2.0,
window_fn=torch.hamming_window,
)
stft_complex = torchaudio.transforms.Spectrogram(
n_fft=512,
win_length=200,
hop_length=80,
power=None,
window_fn=torch.hamming_window,
)
istft = torchaudio.transforms.InverseSpectrogram(
n_fft=512,
win_length=200,
hop_length=80,
window_fn=torch.hamming_window,
)
def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
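    """
    Apply the predicted speech IRM (and its complement for noise) to the
    complex mixture spectrogram, then invert back to time-domain waveforms.
    """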
mix_spec_complex = mix_spec_complex.detach().cpu()
speech_irm_prediction = speech_irm_prediction.detach().cpu()
mask_speech = speech_irm_prediction
mask_noise = 1.0 - speech_irm_prediction
speech_spec = mix_spec_complex * mask_speech
noise_spec = mix_spec_complex * mask_noise
speech_wave = istft.forward(speech_spec)
noise_wave = istft.forward(noise_spec)
return speech_wave, noise_wave
def save_audios(noise_wave: torch.Tensor,
speech_wave: torch.Tensor,
mix_wave: torch.Tensor,
speech_wave_enhanced: torch.Tensor,
noise_wave_enhanced: torch.Tensor,
output_dir: str,
sample_rate: int = 8000,
):
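    """Write all five waveforms into a fresh UUID-named subdirectory and return its path."""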
    basename = str(uuid.uuid4())
output_dir = Path(output_dir) / basename
output_dir.mkdir(parents=True, exist_ok=True)
filename = output_dir / "noise_wave.wav"
torchaudio.save(filename, noise_wave, sample_rate)
filename = output_dir / "speech_wave.wav"
torchaudio.save(filename, speech_wave, sample_rate)
filename = output_dir / "mix_wave.wav"
torchaudio.save(filename, mix_wave, sample_rate)
filename = output_dir / "speech_wave_enhanced.wav"
torchaudio.save(filename, speech_wave_enhanced, sample_rate)
filename = output_dir / "noise_wave_enhanced.wav"
torchaudio.save(filename, noise_wave_enhanced, sample_rate)
return output_dir.as_posix()
def main():
args = get_args()
logger = logging_config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
logger.info("prepare model")
model = SimpleLinearIRMPretrainedModel.from_pretrained(
pretrained_model_name_or_path=args.model_dir,
)
model.to(device)
model.eval()
    # loss function
logger.info("prepare loss_fn")
mse_loss = nn.MSELoss(
reduction="mean",
)
logger.info("read excel")
df = pd.read_excel(args.valid_dataset)
total_loss = 0.
total_examples = 0.
progress_bar = tqdm(total=len(df), desc="Evaluation")
for idx, row in df.iterrows():
noise_filename = row["noise_filename"]
noise_offset = row["noise_offset"]
noise_duration = row["noise_duration"]
speech_filename = row["speech_filename"]
speech_offset = row["speech_offset"]
speech_duration = row["speech_duration"]
snr_db = row["snr_db"]
noise_wave, _ = librosa.load(
noise_filename,
sr=8000,
offset=noise_offset,
duration=noise_duration,
)
speech_wave, _ = librosa.load(
speech_filename,
sr=8000,
offset=speech_offset,
duration=speech_duration,
)
mix_wave: np.ndarray = mix_speech_and_noise(
speech=speech_wave,
noise=noise_wave,
snr_db=snr_db,
)
noise_wave = torch.tensor(noise_wave, dtype=torch.float32)
speech_wave = torch.tensor(speech_wave, dtype=torch.float32)
mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32)
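        # add a leading channel dimension expected by the torchaudio transforms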
noise_wave = noise_wave.unsqueeze(dim=0)
speech_wave = speech_wave.unsqueeze(dim=0)
mix_wave = mix_wave.unsqueeze(dim=0)
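        # power spectrograms feed the IRM target; the complex spectrogram is kept for reconstruction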
noise_spec: torch.Tensor = stft_power.forward(noise_wave)
speech_spec: torch.Tensor = stft_power.forward(speech_wave)
mix_spec: torch.Tensor = stft_power.forward(mix_wave)
mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)
        # ideal ratio mask in the power domain: |S|^2 / (|S|^2 + |N|^2);
        # the small epsilon guards against division by zero in silent bins
        speech_irm = speech_spec / (noise_spec + speech_spec + 1e-8)
        # an exponent of 1.0 is a no-op, kept as a hook for mask compression
        speech_irm = torch.pow(speech_irm, 1.0)
mix_spec = mix_spec.to(device)
speech_irm_target = speech_irm.to(device)
with torch.no_grad():
speech_irm_prediction = model.forward(mix_spec)
loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)
speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
save_audios(noise_wave, speech_wave, mix_wave, speech_wave_enhanced, noise_wave_enhanced, args.evaluation_audio_dir)
total_loss += loss.item()
total_examples += mix_spec.size(0)
evaluation_loss = total_loss / total_examples
evaluation_loss = round(evaluation_loss, 4)
progress_bar.update(1)
progress_bar.set_postfix({
"evaluation_loss": evaluation_loss,
})
        if idx + 1 >= args.limit:
break
return
if __name__ == '__main__':
main()
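
# Example invocation (script name and file paths are illustrative):
#   python3 step_3_evaluation.py --valid_dataset valid.xlsx --model_dir serialization_dir/best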