|
import os |
|
import json |
|
from tqdm import tqdm |
|
from copy import deepcopy |
|
|
|
import numpy as np |
|
import gradio as gr |
|
import torch |
|
|
|
import random |
|
random.seed(0) |
|
torch.manual_seed(0) |
|
np.random.seed(0) |
|
|
|
from scipy.io.wavfile import write as wavwrite |
|
|
|
from util import print_size, sampling |
|
from network import CleanUNet |
|
import torchaudio |
|
|
|
def load_simple(filename):
    """Load an audio file with torchaudio and return only the waveform.

    ``torchaudio.load`` returns a pair; the second element is discarded here,
    so callers only ever see the audio tensor.
    """
    waveform, _discarded = torchaudio.load(filename)
    return waveform
|
|
|
# Locations of the JSON configuration and of the pretrained weights.
CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-high/checkpoint/pretrained.pkl"


# Parse the configuration once at import time. The sections are plain
# module-level names (no `global` statement is needed at module scope);
# denoise() reads network_config and train_config from here.
with open(CONFIG, encoding="utf-8") as f:
    data = f.read()

config = json.loads(data)
gen_config = config["gen_config"]
network_config = config["network_config"]
train_config = config["train_config"]
trainset_config = config["trainset_config"]
|
|
|
def denoise(files, ckpt_path, *, sample_rate=32000, chunk_size=1000000):
    """
    Denoise one or more audio files with a CleanUNet checkpoint.

    Each input is loaded with load_simple, split along dim 1 into chunks,
    passed through the network chunk by chunk, and written next to the input
    as "<input file name>_denoised.wav".

    Parameters:
    files (str, os.PathLike or iterable of them): path(s) of the audio to
                            denoise; a single path is accepted as-is instead
                            of being iterated character by character
    ckpt_path (str):        checkpoint file to load; it must contain a
                            'model_state_dict' entry
    sample_rate (int):      rate passed to the WAV writer (default 32000)
    chunk_size (int):       approximate number of samples per forward pass
                            (default 1000000)

    Returns:
    str or list of str: path of the written file when a single path was
                        given, otherwise the list of written paths

    Raises:
    ValueError: if chunk_size is not positive
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")

    # A lone path must not be looped over as a sequence of characters.
    single_input = isinstance(files, (str, os.PathLike))
    file_paths = [files] if single_input else list(files)

    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # Build the network and restore the pretrained weights on CPU.
    net = CleanUNet(**network_config)
    print_size(net)

    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    saved_files = []
    for file_path in tqdm(file_paths):
        file_path = os.fspath(file_path)
        # Save beside the input: the directory comes from the full path
        # (the dirname of a bare file name is always empty).
        save_file = os.path.join(
            os.path.dirname(file_path),
            os.path.basename(file_path) + "_denoised.wav",
        )

        noisy_audio = load_simple(file_path)
        # Chunks are cut along dim 1, so that is the length that matters.
        num_samples = noisy_audio.shape[1]
        chunks = torch.chunk(noisy_audio, num_samples // chunk_size + 1, dim=1)

        denoised_chunks = []
        for chunk in tqdm(chunks):
            with torch.no_grad():
                generated_audio = sampling(net, chunk)
            denoised_chunks.append(generated_audio.cpu().numpy().squeeze())

        # NOTE(review): axis=0 joins chunks in time only if each squeezed
        # chunk is 1-D (presumably mono input) — confirm for multi-channel.
        all_audio = np.concatenate(denoised_chunks, axis=0)

        # NOTE(review): load_simple discards the rate of the input file, so
        # sample_rate is presumably expected to match it — confirm.
        wavwrite(save_file, sample_rate, all_audio.squeeze())
        print("saved to:", save_file)
        saved_files.append(save_file)

    return saved_files[0] if single_input else saved_files
|
|
|
|
|
# Gradio front end: one audio upload in, one audio file out.
audio = gr.inputs.Audio(label = "Audio to denoise", type = 'filepath')
# Only real input components belong here; the checkpoint path is a fixed
# setting, not something the user supplies, so it is bound in the callback.
inputs = [audio]
outputs = gr.outputs.Audio(label = "Denoised audio", type = 'filepath')

title = "Speech Denoising in the Waveform Domain with Self-Attention from Nvidia"


def _denoise_with_default_checkpoint(audio_path):
    """Run denoise on the uploaded file using the bundled CHECKPOINT."""
    # NOTE(review): the 'filepath' output needs the callback to return the
    # path of the denoised file, and denoise loops over its first argument —
    # confirm denoise accepts a single path and returns the saved file.
    return denoise(audio_path, CHECKPOINT)


gr.Interface(_denoise_with_default_checkpoint, inputs, outputs, title=title, enable_queue=True).launch()