import os
import json

import wget
import torch
import torchaudio
import gradio as gr

from dcc_tf import Net as Waveformer
TARGETS = [
    "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
    "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
    "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
    "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
    "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
    "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
    "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
    "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
    "Trumpet", "Violin_or_fiddle", "Writing"
]
# Fetch the model configuration and checkpoint on first run
if not os.path.exists('default_config.json'):
    config_url = 'https://targetsound.cs.washington.edu/files/default_config.json'
    print("Downloading model configuration from %s:" % config_url)
    wget.download(config_url)

if not os.path.exists('default_ckpt.pt'):
    ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'
    print("\nDownloading the checkpoint from %s:" % ckpt_url)
    wget.download(ckpt_url)
# Instantiate the model and load the pretrained weights for CPU inference
with open('default_config.json') as f:
    params = json.load(f)
model = Waveformer(**params['model_params'])
model.load_state_dict(
    torch.load('default_ckpt.pt', map_location=torch.device('cpu'))['model_state_dict'])
model.eval()
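# Optional sanity check (a minimal sketch, assuming the (batch, channels,
# samples) mixture and (batch, num_targets) query conventions used in
# waveformer() below; the output waveform has the same shape as the mixture):
#
#     with torch.no_grad():
#         out = model(torch.zeros(1, 1, 44100), torch.zeros(1, len(TARGETS)))
#     assert out.shape == (1, 1, 44100)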
def waveformer(audio, label_choices):
    # Read the input audio: gr.Audio yields (sample_rate, int16 samples)
    fs, mixture = audio
    if fs != 44100:
        raise ValueError("Sampling rate must be 44100, but got %d" % fs)
    # Downmix stereo to mono, since the model expects a single channel
    if mixture.ndim > 1:
        mixture = mixture.mean(axis=1)
    # Normalize int16 samples to [-1, 1] and shape to (batch, channels, samples)
    mixture = torch.from_numpy(
        mixture).unsqueeze(0).unsqueeze(0).to(torch.float) / (2.0 ** 15)

    # Construct the multi-hot query vector from the selected target labels
    query = torch.zeros(1, len(TARGETS))
    for t in label_choices:
        query[0, TARGETS.index(t)] = 1.

    # Extract the queried sounds and rescale back to the int16 range
    with torch.no_grad():
        output = (2.0 ** 15) * model(mixture, query)

    return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()
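# Standalone usage (a minimal sketch, assuming a 44.1 kHz WAV at the
# hypothetical path "mixture.wav"; torchaudio.load returns float32 samples in
# [-1, 1], so they are rescaled to the int16 range waveformer() expects):
#
#     waveform, fs = torchaudio.load("mixture.wav")
#     samples = (waveform[0] * (2.0 ** 15)).to(torch.short).numpy()
#     fs_out, extracted = waveformer((fs, samples), ["Meow", "Bark"])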
input_audio = gr.Audio(label="Input audio") | |
label_checkbox = gr.CheckboxGroup(choices=TARGETS, label="Input target selection(s)") | |
output_audio = gr.Audio(label="Output audio") | |
demo = gr.Interface(fn=waveformer, inputs=[input_audio, label_checkbox], outputs=output_audio) | |
demo.launch(show_error=True) | |