import argparse
import os

import wget
import torch
import torchaudio
import gradio as gr

from src.helpers import utils  # assumed location of the Params config loader used below
from src.training.dcc_tf import Net as Waveformer
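# Target sound classes the pretrained checkpoint can be queried with; the
# checkbox labels below map one-to-one onto entries of the query vector.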
TARGETS = [
    "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
    "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
    "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
    "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
    "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
    "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
    "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
    "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
    "Trumpet", "Violin_or_fiddle", "Writing"
]
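# Fetch the pretrained model configuration and checkpoint if they are not
# already present in the working directory.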
if not os.path.exists('default_config.json'):
    config_url = 'https://targetsound.cs.washington.edu/files/default_config.json'
    print("Downloading model configuration from %s:" % config_url)
    wget.download(config_url)

if not os.path.exists('default_ckpt.pt'):
    ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'
    print("\nDownloading the checkpoint from %s:" % ckpt_url)
    wget.download(ckpt_url)
# Instantiate model
params = utils.Params('default_config.json')
model = Waveformer(**params.model_params)
model.load_state_dict(
    torch.load('default_ckpt.pt', map_location=torch.device('cpu')))
model.eval()
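# Gradio callback: receives the mixture as (sample_rate, int16 samples) plus the
# selected target labels, and returns the extracted audio in the same format.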
def waveformer(audio, label_choices):
    # Read input audio
    fs, mixture = audio
    if fs != 44100:
        raise ValueError("Expected a sampling rate of 44100 Hz, got %d Hz." % fs)
    # Convert int16 samples to a float tensor in [-1, 1] with shape (1, 1, T)
    mixture = torch.from_numpy(
        mixture).unsqueeze(0).unsqueeze(0).to(torch.float) / (2.0 ** 15)
    # Construct the (multi-hot) query vector from the selected labels
    if len(label_choices) == 0:
        raise ValueError("At least one target sound must be selected.")
    query = torch.zeros(1, len(TARGETS))
    for t in label_choices:
        query[0, TARGETS.index(t)] = 1.
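    # Run the model and rescale the output from [-1, 1] back to the int16 range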
    with torch.no_grad():
        output = (2.0 ** 15) * model(mixture, query)
    return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
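# Build and launch the Gradio demo: an audio input, a checkbox group for the
# target sounds, and an audio output with the extracted result.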
label_checkbox = gr.CheckboxGroup(choices=TARGETS)
demo = gr.Interface(fn=waveformer, inputs=['audio', label_checkbox], outputs="audio")

demo.launch()