import os

import gradio as gr
import torch
import wget

from src.helpers import utils
from src.training.dcc_tf import Net as Waveformer

# Target sound classes supported by the pretrained Waveformer checkpoint.
TARGETS = [
    "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
    "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
    "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
    "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
    "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
    "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
    "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
    "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
    "Trumpet", "Violin_or_fiddle", "Writing"
]

# Download the model configuration and checkpoint if not already present.
if not os.path.exists('default_config.json'):
    config_url = 'https://targetsound.cs.washington.edu/files/default_config.json'
    print("Downloading model configuration from %s:" % config_url)
    wget.download(config_url)

if not os.path.exists('default_ckpt.pt'):
    ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'
    print("\nDownloading the checkpoint from %s:" % ckpt_url)
    wget.download(ckpt_url)

# Instantiate the model and load the pretrained weights.
params = utils.Params('default_config.json')
model = Waveformer(**params.model_params)
utils.load_checkpoint('default_ckpt.pt', model)
model.eval()


def waveformer(audio, label_choices):
    # Read the input audio; the model expects a 44.1 kHz sampling rate.
    fs, mixture = audio
    if fs != 44100:
        raise ValueError("Expected a sampling rate of 44100 Hz, got %d Hz" % fs)
    mixture = torch.from_numpy(
        mixture).unsqueeze(0).unsqueeze(0).to(torch.float)

    # Construct the query vector: a multi-hot encoding over TARGETS.
    if len(label_choices) == 0:
        raise ValueError("At least one target class must be selected")
    query = torch.zeros(1, len(TARGETS))
    for t in label_choices:
        query[0, TARGETS.index(t)] = 1.

    # Run the model to extract the queried sounds from the mixture.
    with torch.no_grad():
        output = model(mixture, query)

    return fs, output.squeeze(0).squeeze(0).to(torch.short).numpy()


label_checkbox = gr.CheckboxGroup(choices=TARGETS)

demo = gr.Interface(fn=waveformer,
                    inputs=['audio', label_checkbox],
                    outputs="audio")

demo.launch()
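
# Usage note (a sketch, assuming this script is saved as app.py at the root of
# the Waveformer repository so that the src/ package resolves):
#
#   python app.py
#
# Gradio serves the demo on a local URL (http://127.0.0.1:7860 by default).
# Upload a 44.1 kHz mixture, tick one or more target classes, and play back
# the extracted audio returned by the model.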