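"""Gradio demo for Waveformer: query-based target sound extraction."""
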
import argparse
import os
import wget
import torch
import gradio as gr
from src.helpers import utils
from src.training.dcc_tf import Net as Waveformer
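
# Sound classes the model can be queried to extract.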
TARGETS = [
"Acoustic_guitar", "Applause", "Bark", "Bass_drum",
"Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
"Computer_keyboard", "Cough", "Cowbell", "Double_bass",
"Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
"Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
"Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
"Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
"Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
"Trumpet", "Violin_or_fiddle", "Writing"
]
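
# Download the model configuration and checkpoint if they are not already cached.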
if not os.path.exists('default_config.json'):
    config_url = 'https://targetsound.cs.washington.edu/files/default_config.json'
    print("Downloading model configuration from %s:" % config_url)
    wget.download(config_url)

if not os.path.exists('default_ckpt.pt'):
    ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'
    print("\nDownloading the checkpoint from %s:" % ckpt_url)
    wget.download(ckpt_url)

# Instantiate the model and load the pre-trained checkpoint.
params = utils.Params('default_config.json')
model = Waveformer(**params.model_params)
utils.load_checkpoint('default_ckpt.pt', model)
model.eval()

def waveformer(audio, label_choices):
    """Extract the selected target sounds from the input audio mixture."""
    # Read the input audio; Gradio provides a (sampling rate, numpy array) pair.
    fs, mixture = audio
    if fs != 44100:
        raise ValueError(
            "Expected a sampling rate of 44100 Hz, got %d Hz" % fs)
    mixture = torch.from_numpy(
        mixture).unsqueeze(0).unsqueeze(0).to(torch.float)

    # Construct the one-hot query vector from the selected labels.
    if len(label_choices) == 0:
        raise ValueError("At least one target sound must be selected")
    query = torch.zeros(1, len(TARGETS))
    for t in label_choices:
        query[0, TARGETS.index(t)] = 1.

    # Run the model and convert the output back to 16-bit integer samples.
    with torch.no_grad():
        output = model(mixture, query)
    return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
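
# Build the Gradio interface: an audio input plus a multi-select of target sounds.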
label_checkbox = gr.CheckboxGroup(choices=TARGETS)
demo = gr.Interface(fn=waveformer, inputs=['audio', label_checkbox], outputs="audio")
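# launch() starts a local Gradio server; on Hugging Face Spaces this serves the demo.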
demo.launch()