marigold334 commited on
Commit
d8c4b79
โ€ข
1 Parent(s): bb04426
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -23,17 +23,17 @@ class TTS:
23
  last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
24
  elif model_variant == 'KSS':
25
  last_chpt1 = './log/KSS/Glow_TTS_00280641.pt'
26
- elif model_variant == 'ํƒœ์—ฐ':
27
  last_chpt1 = './log/Taeyeon/Glow_TTS_337000.pt'
28
  check_point = torch.load(last_chpt1, map_location = device)
29
- self.flowgenerator.load_state_dict(check_point['generator' if model_variant != 'ํƒœ์—ฐ' else 'model'])
30
- self.flowgenerator.decoder.skip() if model_variant != 'ํƒœ์—ฐ' else None
31
  self.flowgenerator.eval()
32
  if model_variant == '๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹':
33
- last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
34
  elif model_variant == 'KSS':
35
  last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
36
- elif model_variant == 'ํƒœ์—ฐ':
37
  last_chpt2 = './log/Taeyeon/HiFi_GAN_400000.pt'
38
  check_point = torch.load(last_chpt2, map_location = device)
39
  self.voicegenerator.load_state_dict(check_point['gen_model'])
@@ -41,7 +41,7 @@ class TTS:
41
  self.voicegenerator.remove_weight_norm()
42
 
43
  def inference(self, input_text, noise_scale = 0.667, length_scale = 1.0):
44
- filters = '([.,!?])' if st.session_state != 'ํƒœ์—ฐ' else '([,])'
45
  sentence = re.sub(re.compile(filters), '', input_text)
46
  x = text_to_sequence(sentence)
47
  x = torch.autograd.Variable(torch.tensor(x).unsqueeze(0)).to(device).long()
@@ -57,16 +57,16 @@ class TTS:
57
  def init_session_state():
58
  if "init_model" not in st.session_state:
59
  st.session_state.init_model = True
60
- st.session_state.model_variant = "ํƒœ์—ฐ"
61
- st.session_state.TTS = TTS("ํƒœ์—ฐ")
62
 
63
  def update_model():
64
  if st.session_state.model_variant == "KSS":
65
  st.session_state.TTS = TTS("KSS")
66
  elif st.session_state.model_variant == "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹":
67
  st.session_state.TTS = TTS("๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹")
68
- elif st.session_state.model_variant == 'ํƒœ์—ฐ':
69
- st.session_state.TTS = TTS("ํƒœ์—ฐ")
70
 
71
  def update_session_state(state_id, state_value):
72
  st.session_state[f"{state_id}"] = state_value
@@ -97,7 +97,7 @@ with col1:
97
  value = "๋ฐฅ์€ ๋จน๊ณ  ๋‹ค๋…€?",
98
  )
99
  with col2:
100
- model_variant = st.selectbox("๋ชฉ์†Œ๋ฆฌ ์„ ํƒํ•ด์ฃผ์„ธ์š”", options = ["KSS", "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹", "ํƒœ์—ฐ"], index = 1)
101
 
102
  button_change = st.button("Change Vocie")
103
  if button_change == True:
@@ -107,8 +107,8 @@ if button_change == True:
107
  update_model()
108
  st.success('Done!', icon="โœ…")
109
 
110
- noise_scale = st.slider('noise๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.', 0., 2., value = 0.33, step = 0.01)
111
- length_scale = st.slider('์†๋„๋ฅผ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.', 0., 2., value = 1., step = 0.01)
112
 
113
  button_gen = st.button("Generate Voice")
114
  if button_gen == True:
 
23
  last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
24
  elif model_variant == 'KSS':
25
  last_chpt1 = './log/KSS/Glow_TTS_00280641.pt'
26
+ elif model_variant == '์ˆ ์ทจํ•œ ํƒœ์—ฐ':
27
  last_chpt1 = './log/Taeyeon/Glow_TTS_337000.pt'
28
  check_point = torch.load(last_chpt1, map_location = device)
29
+ self.flowgenerator.load_state_dict(check_point['generator' if model_variant != '์ˆ ์ทจํ•œ ํƒœ์—ฐ' else 'model'])
30
+ self.flowgenerator.decoder.skip() if model_variant != '์ˆ ์ทจํ•œ ํƒœ์—ฐ' else None
31
  self.flowgenerator.eval()
32
  if model_variant == '๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹':
33
+ last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00664000.pt'
34
  elif model_variant == 'KSS':
35
  last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
36
+ elif model_variant == '์ˆ ์ทจํ•œ ํƒœ์—ฐ':
37
  last_chpt2 = './log/Taeyeon/HiFi_GAN_400000.pt'
38
  check_point = torch.load(last_chpt2, map_location = device)
39
  self.voicegenerator.load_state_dict(check_point['gen_model'])
 
41
  self.voicegenerator.remove_weight_norm()
42
 
43
  def inference(self, input_text, noise_scale = 0.667, length_scale = 1.0):
44
+ filters = '([.,!?])' if st.session_state != '์ˆ ์ทจํ•œ ํƒœ์—ฐ' else '([,])'
45
  sentence = re.sub(re.compile(filters), '', input_text)
46
  x = text_to_sequence(sentence)
47
  x = torch.autograd.Variable(torch.tensor(x).unsqueeze(0)).to(device).long()
 
57
  def init_session_state():
58
  if "init_model" not in st.session_state:
59
  st.session_state.init_model = True
60
+ st.session_state.model_variant = "์ˆ ์ทจํ•œ ํƒœ์—ฐ"
61
+ st.session_state.TTS = TTS("์ˆ ์ทจํ•œ ํƒœ์—ฐ")
62
 
63
  def update_model():
64
  if st.session_state.model_variant == "KSS":
65
  st.session_state.TTS = TTS("KSS")
66
  elif st.session_state.model_variant == "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹":
67
  st.session_state.TTS = TTS("๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹")
68
+ elif st.session_state.model_variant == '์ˆ ์ทจํ•œ ํƒœ์—ฐ':
69
+ st.session_state.TTS = TTS("์ˆ ์ทจํ•œ ํƒœ์—ฐ")
70
 
71
  def update_session_state(state_id, state_value):
72
  st.session_state[f"{state_id}"] = state_value
 
97
  value = "๋ฐฅ์€ ๋จน๊ณ  ๋‹ค๋…€?",
98
  )
99
  with col2:
100
+ model_variant = st.selectbox("๋ชฉ์†Œ๋ฆฌ ์„ ํƒํ•ด์ฃผ์„ธ์š”", options = ["KSS", "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹", "์ˆ ์ทจํ•œ ํƒœ์—ฐ"], index = 1)
101
 
102
  button_change = st.button("Change Vocie")
103
  if button_change == True:
 
107
  update_model()
108
  st.success('Done!', icon="โœ…")
109
 
110
+ noise_scale = st.slider('noise๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.', 0., 2., value = 0.3, step = 0.1)
111
+ length_scale = st.slider('์†๋„๋ฅผ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.', 0., 2., value = 1., step = 0.1)
112
 
113
  button_gen = st.button("Generate Voice")
114
  if button_gen == True: