marigold334 commited on
Commit
7d58dde
1 Parent(s): b318680

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -23,6 +23,8 @@ class TTS:
23
  last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
24
  elif model_variant == 'KSS':
25
  last_chpt1 = './log/KSS/Glow_TTS_00280641.pt'
 
 
26
  check_point = torch.load(last_chpt1, map_location = device)
27
  self.flowgenerator.load_state_dict(check_point['generator'])
28
  self.flowgenerator.decoder.skip()
@@ -31,6 +33,8 @@ class TTS:
31
  last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
32
  elif model_variant == 'KSS':
33
  last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
 
 
34
  check_point = torch.load(last_chpt2, map_location = device)
35
  self.voicegenerator.load_state_dict(check_point['gen_model'])
36
  self.voicegenerator.eval()
@@ -54,14 +58,16 @@ def init_session_state():
54
  # Model
55
  if "init_model" not in st.session_state:
56
  st.session_state.init_model = True
57
- st.session_state.model_variant = "감기걸린 은식"
58
- st.session_state.TTS = TTS("감기걸린 은식")
59
 
60
  def update_model():
61
  if st.session_state.model_variant == "KSS":
62
  st.session_state.TTS = TTS("KSS")
63
  elif st.session_state.model_variant == "감기걸린 은식":
64
  st.session_state.TTS = TTS("감기걸린 은식")
 
 
65
 
66
  def update_session_state(state_id, state_value):
67
  st.session_state[f"{state_id}"] = state_value
@@ -89,19 +95,18 @@ col1, col2 = st.columns(2)
89
  with col1:
90
  input_text = st.text_input(
91
  "한글로만 입력해주세요",
92
- value = "밥은 먹고 다녀",
93
  )
94
  with col2:
95
- model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "감기걸린 은식"], index = 1)
96
 
97
  button_change = st.button("Change Vocie")
98
  if button_change == True:
99
  if model_variant != st.session_state.model_variant:
100
- # Update variant choice
101
- update_session_state("model_variant", model_variant)
102
- # Re-load model
103
- update_model()
104
- st.snow()
105
 
106
  noise_scale = st.slider('noise를 추가합니다.', 0., 2., value = 0.33, step = 0.01)
107
  length_scale = st.slider('속도를 조절합니다.', 0., 2., value = 1., step = 0.01)
 
23
  last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
24
  elif model_variant == 'KSS':
25
  last_chpt1 = './log/KSS/Glow_TTS_00280641.pt'
26
+ elif model_variant == '태연':
27
+ last_chpt1 = './log/Taeyeon/Glow_TTS_400000.pt'
28
  check_point = torch.load(last_chpt1, map_location = device)
29
  self.flowgenerator.load_state_dict(check_point['generator'])
30
  self.flowgenerator.decoder.skip()
 
33
  last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
34
  elif model_variant == 'KSS':
35
  last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
36
+ elif model_variant == '태연':
37
+ last_chpt2 = './log/Taeyeon/HiFi_GAN_337000.pt'
38
  check_point = torch.load(last_chpt2, map_location = device)
39
  self.voicegenerator.load_state_dict(check_point['gen_model'])
40
  self.voicegenerator.eval()
 
58
  # Model
59
  if "init_model" not in st.session_state:
60
  st.session_state.init_model = True
61
+ st.session_state.model_variant = "태연"
62
+ st.session_state.TTS = TTS("태연")
63
 
64
  def update_model():
65
  if st.session_state.model_variant == "KSS":
66
  st.session_state.TTS = TTS("KSS")
67
  elif st.session_state.model_variant == "감기걸린 은식":
68
  st.session_state.TTS = TTS("감기걸린 은식")
69
+ elif st.seesion_state.model_varaiant == '태연':
70
+ st.session_state.TTS = TTS("태연")
71
 
72
  def update_session_state(state_id, state_value):
73
  st.session_state[f"{state_id}"] = state_value
 
95
  with col1:
96
  input_text = st.text_input(
97
  "한글로만 입력해주세요",
98
+ value = "밥은 먹고 다녀?",
99
  )
100
  with col2:
101
+ model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "감기걸린 은식", "태연"], index = 1)
102
 
103
  button_change = st.button("Change Vocie")
104
  if button_change == True:
105
  if model_variant != st.session_state.model_variant:
106
+ with st.spinner('Wait for it...'):
107
+ update_session_state("model_variant", model_variant)
108
+ update_model()
109
+ st.success('Done!', icon="✅")
 
110
 
111
  noise_scale = st.slider('noise를 추가합니다.', 0., 2., value = 0.33, step = 0.01)
112
  length_scale = st.slider('속도를 조절합니다.', 0., 2., value = 1., step = 0.01)