Spaces:
Runtime error
Runtime error
COnfig fix
Browse files
app.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
import argparse
|
2 |
import os
|
|
|
3 |
|
4 |
import wget
|
5 |
import torch
|
6 |
import torchaudio
|
7 |
import gradio as gr
|
8 |
|
9 |
-
from
|
10 |
|
11 |
TARGETS = [
|
12 |
"Acoustic_guitar", "Applause", "Bark", "Bass_drum",
|
@@ -31,9 +32,11 @@ if not os.path.exists('default_ckpt.pt'):
|
|
31 |
wget.download(ckpt_url)
|
32 |
|
33 |
# Instantiate model
|
34 |
-
|
35 |
-
|
36 |
-
model
|
|
|
|
|
37 |
model.eval()
|
38 |
|
39 |
def waveformer(audio, label_choices):
|
|
|
1 |
import argparse
|
2 |
import os
|
3 |
+
import json
|
4 |
|
5 |
import wget
|
6 |
import torch
|
7 |
import torchaudio
|
8 |
import gradio as gr
|
9 |
|
10 |
+
from dcc_tf import Net as Waveformer
|
11 |
|
12 |
TARGETS = [
|
13 |
"Acoustic_guitar", "Applause", "Bark", "Bass_drum",
|
|
|
32 |
wget.download(ckpt_url)
|
33 |
|
34 |
# Instantiate model
|
35 |
+
with open('default_config.json') as f:
|
36 |
+
params = json.load(f)
|
37 |
+
model = Waveformer(**params['model_params'])
|
38 |
+
model.load_state_dict(
|
39 |
+
torch.load('default_ckpt.pt', map_location=torch.device('cpu'))['model_state_dict'])
|
40 |
model.eval()
|
41 |
|
42 |
def waveformer(audio, label_choices):
|