import json
from tqdm import tqdm
from copy import deepcopy

import soundfile as sf
import numpy as np
import gradio as gr
import torch

import random
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

from util import print_size, sampling
from network import CleanUNet
import torchaudio
import torchaudio.transforms as T

SAMPLE_RATE = 22050

def load_simple(filename):
    # Load the audio file and resample it to SAMPLE_RATE before inference.
    wav, sr = torchaudio.load(filename)
    resampler = T.Resample(sr, SAMPLE_RATE, dtype=wav.dtype)
    return resampler(wav)

CONFIG = "configs/DNS-large-full.json"
CHECKPOINT = "./exp/DNS-large-full/checkpoint/pretrained.pkl"
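# Assumption: the DNS-large-full config file and the pretrained CleanUNet
# checkpoint have been downloaded beforehand and placed at the two paths above.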

# Parse configs; these are kept as module-level globals for simplicity.
with open(CONFIG) as f:
    config = json.load(f)
    gen_config      = config["gen_config"]
    network_config  = config["network_config"]      # CleanUNet architecture settings
    train_config    = config["train_config"]        # training / experiment settings
    trainset_config = config["trainset_config"]     # dataset settings

def denoise(filename, ckpt_path=CHECKPOINT, out="out.wav"):
    """
    Denoise a single audio file with a pretrained CleanUNet model.

    Parameters:
    filename (str):      path to the noisy input audio file
    ckpt_path (str):     path to the pretrained checkpoint to load
    out (str):           path where the denoised audio is written

    Returns:
    str:                 the path of the written (denoised) audio file
    """

    # setup local experiment path
    exp_path = train_config["exp_path"]
    print('exp_path:', exp_path)

    # load data
    loader_config = deepcopy(trainset_config)
    loader_config["crop_length_sec"] = 0

    # predefine model
    net = CleanUNet(**network_config)
    print_size(net)

    # load checkpoint
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    net.load_state_dict(checkpoint['model_state_dict'])
    net.eval()

    # inference
    noisy_audio = load_simple(filename)

    with torch.no_grad():
        # Mixed-precision autocast only takes effect on a CUDA device;
        # on CPU it is effectively a no-op and inference runs in full precision.
        with torch.cuda.amp.autocast():
            generated_audio = sampling(net, noisy_audio)
            generated_audio = generated_audio[0].squeeze().cpu().numpy()
            sf.write(out, np.ravel(generated_audio), SAMPLE_RATE)

    return out
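
# Example (hypothetical file names): calling denoise() directly, without the UI:
#   out_path = denoise("noisy_sample.wav", out="denoised_sample.wav")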

# audio = gr.inputs.Audio(label = "Audio to denoise", type = 'filepath')
# inputs = [audio]
# outputs = gr.outputs.Audio(label = "Denoised audio", type = 'filepath')

# title = "Speech Denoising in the Waveform Domain with Self-Attention from Nvidia"

# gr.Interface(denoise, inputs, outputs, title=title, enable_queue=True).launch()


# Note: this uses the older gr.inputs / gr.outputs component API; newer Gradio
# releases expose these components at the top level (e.g. gr.Audio).
demo = gr.Interface(
    fn=denoise,
    inputs=[
        gr.inputs.Audio(source="microphone", label="Audio to denoise", type="filepath", optional=True),
    ],
    outputs=gr.outputs.Audio(label="Denoised audio", type="filepath"),
    layout="horizontal",
    # theme="huggingface",
    title="My Demo: Speech enhancement",
    description="Record audio with the microphone and get back the denoised waveform produced by CleanUNet.",
    allow_flagging="never",
)

demo.launch()
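
# By default launch() serves the demo locally; passing share=True (a standard
# Gradio option) creates a temporary public link instead, e.g.:
#   demo.launch(share=True)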