marigold334 commited on
Commit
13dbff0
β€’
1 Parent(s): 371ba49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -22,14 +22,14 @@ class TTS:
22
  name = '1038_eunsik_01'
23
  last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
24
  check_point = torch.load(last_chpt1)
25
- self.flowgenerator.load_state_dict(check_point['generator'])
26
  self.flowgenerator.decoder.skip()
27
  self.flowgenerator.eval()
28
  if model_variant == '은식':
29
  name = '1038_eunsik_01'
30
  last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
31
  check_point = torch.load(last_chpt2)
32
- self.voicegenerator.load_state_dict(check_point['gen_model'])
33
  self.voicegenerator.eval()
34
  self.voicegenerator.remove_weight_norm()
35
 
 
22
  name = '1038_eunsik_01'
23
  last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
24
  check_point = torch.load(last_chpt1)
25
+ self.flowgenerator.load_state_dict(check_point['generator'], map_location = device)
26
  self.flowgenerator.decoder.skip()
27
  self.flowgenerator.eval()
28
  if model_variant == '은식':
29
  name = '1038_eunsik_01'
30
  last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
31
  check_point = torch.load(last_chpt2)
32
+ self.voicegenerator.load_state_dict(check_point['gen_model'], map_location = device)
33
  self.voicegenerator.eval()
34
  self.voicegenerator.remove_weight_norm()
35