bandhav commited on
Commit
de110b1
1 Parent(s): 2db9aa5

COnfig fix

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import argparse
2
  import os
 
3
 
4
  import wget
5
  import torch
6
  import torchaudio
7
  import gradio as gr
8
 
9
- from src.training.dcc_tf import Net as Waveformer
10
 
11
  TARGETS = [
12
  "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
@@ -31,9 +32,11 @@ if not os.path.exists('default_ckpt.pt'):
31
  wget.download(ckpt_url)
32
 
33
  # Instantiate model
34
- params = utils.Params('default_config.json')
35
- model = Waveformer(**params.model_params)
36
- model.load_state_dict(torch.load('default_ckpt.pt', map_location=torch.device('cpu')))
 
 
37
  model.eval()
38
 
39
  def waveformer(audio, label_choices):
 
1
  import argparse
2
  import os
3
+ import json
4
 
5
  import wget
6
  import torch
7
  import torchaudio
8
  import gradio as gr
9
 
10
+ from dcc_tf import Net as Waveformer
11
 
12
  TARGETS = [
13
  "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
 
32
  wget.download(ckpt_url)
33
 
34
  # Instantiate model
35
+ with open('default_config.json') as f:
36
+ params = json.load(f)
37
+ model = Waveformer(**params['model_params'])
38
+ model.load_state_dict(
39
+ torch.load('default_ckpt.pt', map_location=torch.device('cpu'))['model_state_dict'])
40
  model.eval()
41
 
42
  def waveformer(audio, label_choices):