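"""Gradio demo for Waveformer, a real-time target sound extraction model:
given a mixture recording and one or more target sound classes, it isolates
the selected sounds from the mixture.
"""
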
import os

import wget
import torch
import gradio as gr

from src.helpers import utils
from src.training.dcc_tf import Net as Waveformer

# The 41 sound classes the pretrained checkpoint can extract.
TARGETS = [
    "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
    "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
    "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
    "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
    "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
    "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
    "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
    "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
    "Trumpet", "Violin_or_fiddle", "Writing"
]

# Download the pretrained configuration and checkpoint on first run.
if not os.path.exists('default_config.json'):
    config_url = 'https://targetsound.cs.washington.edu/files/default_config.json'
    print(f"Downloading model configuration from {config_url}:")
    wget.download(config_url)

if not os.path.exists('default_ckpt.pt'):
    ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'
    print(f"\nDownloading the checkpoint from {ckpt_url}:")
    wget.download(ckpt_url)

# Instantiate the model and load the pretrained weights from the checkpoint.
params = utils.Params('default_config.json')
model = Waveformer(**params.model_params)
utils.load_checkpoint('default_ckpt.pt', model)
model.eval()

def waveformer(audio, label_choices):
    # Read the input audio; the model expects 44.1 kHz mono input
    fs, mixture = audio
    if fs != 44100:
        raise ValueError(f"Expected a sampling rate of 44100 Hz, got {fs} Hz")
    # Gradio provides int16 samples; add batch and channel dims: (1, 1, T)
    mixture = torch.from_numpy(
        mixture).unsqueeze(0).unsqueeze(0).to(torch.float)

    # Construct the query vector: a multi-hot encoding of the chosen targets
    if len(label_choices) == 0:
        raise ValueError("At least one target sound class must be selected")
    query = torch.zeros(1, len(TARGETS))
    for t in label_choices:
        query[0, TARGETS.index(t)] = 1.

    # Run inference and return int16 samples at the input sampling rate
    with torch.no_grad():
        output = model(mixture, query)

    return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
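
# Note: instead of rejecting other sampling rates, the input could be
# resampled to 44.1 kHz first. A minimal sketch (not part of the original
# demo), assuming torchaudio is available:
#
#   import torchaudio.functional as AF
#   mixture = AF.resample(mixture, orig_freq=fs, new_freq=44100)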


# Build the Gradio UI: an audio input plus a checkbox group of target classes.
label_checkbox = gr.CheckboxGroup(choices=TARGETS)
demo = gr.Interface(fn=waveformer, inputs=['audio', label_checkbox], outputs="audio")
demo.launch()
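
# Running this script downloads the pretrained files if needed and serves the
# demo at a local URL printed by Gradio. A hypothetical programmatic call that
# bypasses the UI (`mixture` being an int16 numpy array sampled at 44.1 kHz):
#
#   fs_out, extracted = waveformer((44100, mixture), ["Meow", "Laughter"])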