# denoiser/app.py: Gradio demo for CleanUNet speech denoising
import json
import random
from copy import deepcopy

import numpy as np
import soundfile as sf
import torch
import torchaudio
import torchaudio.transforms as T
import gradio as gr

from util import print_size, sampling
from network import CleanUNet

# Fix all seeds so inference is deterministic across runs
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

# Target sample rate: every input is resampled to this before inference
SAMPLE_RATE = 22050
def load_simple(filename):
    """Load an audio file and resample it to SAMPLE_RATE."""
    wav, sr = torchaudio.load(filename)
    resampler = T.Resample(sr, SAMPLE_RATE, dtype=wav.dtype)
    return resampler(wav)
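# Usage note: torchaudio.load returns a (channels, num_samples) float tensor,
# so load_simple("sample.wav") yields that tensor resampled to 22050 Hz
# (the "sample.wav" path is a hypothetical example for illustration).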
CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-full/checkpoint/pretrained.pkl"

# Parse configs. Module-level globals are convenient here since the app is a
# single script.
with open(CONFIG) as f:
    config = json.load(f)

gen_config = config["gen_config"]
network_config = config["network_config"]    # defines the CleanUNet architecture
train_config = config["train_config"]        # training/experiment settings
trainset_config = config["trainset_config"]  # trainset configuration
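# Inferred from usage in this script, the config JSON must provide at least:
#   network_config   -> keyword arguments for CleanUNet(**network_config)
#   train_config     -> includes "exp_path"
#   trainset_config  -> includes "crop_length_sec"
#   gen_config       -> generation settings (loaded but not used below)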
def denoise(filename, ckpt_path=CHECKPOINT, out="out.wav"):
    """
    Denoise an audio file with a pretrained CleanUNet model.

    Parameters:
        filename (str): path to the noisy input audio
        ckpt_path (str): path of the pretrained checkpoint to load
        out (str): path where the enhanced (denoised) audio is written

    Returns:
        str: path of the written output file (same value as `out`)
    """
    # setup local experiment path
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # data-loader config (kept from the original training script; unused for
    # single-file inference)
    loader_config = deepcopy(trainset_config)
    loader_config["crop_length_sec"] = 0

    # build the model and report its parameter count
    net = CleanUNet(**network_config)
    print_size(net)

    # load the pretrained checkpoint onto the CPU
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()
    # inference; torch.cuda.amp.autocast is a no-op (with a warning) on
    # CPU-only machines, so only enter it when CUDA is actually available
    noisy_audio = load_simple(filename)
    with torch.no_grad():
        if torch.cuda.is_available():
            with torch.cuda.amp.autocast():
                generated_audio = sampling(net, noisy_audio)
        else:
            generated_audio = sampling(net, noisy_audio)

    # drop the batch dimension and write the result to disk
    generated_audio = generated_audio[0].squeeze().cpu().numpy()
    sf.write(out, np.ravel(generated_audio), SAMPLE_RATE)
    return out
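# Minimal direct-call sketch (the "noisy.wav" path is a hypothetical example);
# uncomment to test denoising without launching the Gradio UI:
#
#   denoise("noisy.wav", out="denoised.wav")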
# Gradio UI. The component API below follows Gradio 4.x (gr.Audio with a
# `sources` list); the original script used the older gr.inputs / gr.outputs
# namespaces and a `layout` argument, both removed in later Gradio releases.
demo = gr.Interface(
    fn=denoise,
    inputs=gr.Audio(sources=["microphone"], type="filepath", label="Audio to denoise"),
    outputs=gr.Audio(type="filepath", label="Denoised audio"),
    title="My Demo: Speech enhancement",
    allow_flagging="never",
)
demo.launch()