marigold334 commited on
Commit
26dee9c
โ€ข
1 Parent(s): f523506

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -18,7 +18,7 @@ class TTS:
18
  torch.cuda.manual_seed(1234) if torch.cuda.is_available() else None
19
  self.flowgenerator = Glow_model(n_vocab = 70, h_c= 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads=2, layers_enc = 6).to(device)
20
  self.voicegenerator = GAN_model().to(device)
21
- if model_variant == '์€์‹':
22
  name = '1038_eunsik_01'
23
  last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
24
  elif model_variant == 'KSS':
@@ -27,7 +27,7 @@ class TTS:
27
  self.flowgenerator.load_state_dict(check_point['generator'])
28
  self.flowgenerator.decoder.skip()
29
  self.flowgenerator.eval()
30
- if model_variant == '์€์‹':
31
  last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
32
  elif model_variant == 'KSS':
33
  last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
@@ -36,7 +36,7 @@ class TTS:
36
  self.voicegenerator.eval()
37
  self.voicegenerator.remove_weight_norm()
38
 
39
- def inference(self, input_text):
40
  filters = '([.,!?])'
41
  sentence = re.sub(re.compile(filters), '', input_text)
42
  x = text_to_sequence(sentence)
@@ -44,8 +44,6 @@ class TTS:
44
  x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)
45
 
46
  with torch.no_grad():
47
- noise_scale = .667
48
- length_scale = 1.0
49
  (y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
50
  y = self.voicegenerator(y_gen_tst)
51
  audio = y.squeeze() * 32768.0
@@ -56,14 +54,14 @@ def init_session_state():
56
  # Model
57
  if "init_model" not in st.session_state:
58
  st.session_state.init_model = True
59
- st.session_state.model_variant = "์€์‹"
60
- st.session_state.TTS = TTS("์€์‹")
61
 
62
  def update_model():
63
  if st.session_state.model_variant == "KSS":
64
  st.session_state.TTS = TTS("KSS")
65
- elif st.session_state.model_variant == "์€์‹":
66
- st.session_state.TTS = TTS("์€์‹")
67
 
68
  def update_session_state(state_id, state_value):
69
  st.session_state[f"{state_id}"] = state_value
@@ -89,7 +87,7 @@ st.write(" ")
89
 
90
  mode = "p"
91
  st.markdown(
92
- f"<{mode} style='text-align: left;'><small>This is a demo trained by our vocie. The voice \"KSS\" is traind by KSS Dataset. \"์€์‹\" which is about 1 hour audio is finetuned from \"KSS\". We got this deomoformat from Nix-TTS Interactive Demo</small></{mode}>",
93
  unsafe_allow_html = True
94
  )
95
 
@@ -100,10 +98,10 @@ col1, col2 = st.columns(2)
100
  with col1:
101
  input_text = st.text_input(
102
  "ํ•œ๊ธ€๋กœ๋งŒ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”",
103
- value = "๋”ฅ๋Ÿฌ๋‹์€ ์ •๋ง ์žฌ๋ฐŒ์–ด!",
104
  )
105
  with col2:
106
- model_variant = st.selectbox("๋ชฉ์†Œ๋ฆฌ ์„ ํƒํ•ด์ฃผ์„ธ์š”", options = ["KSS", "์€์‹"], index = 1)
107
  if model_variant != st.session_state.model_variant:
108
  # Update variant choice
109
  update_session_state("model_variant", model_variant)
@@ -111,9 +109,11 @@ with col2:
111
  update_model()
112
  st.snow()
113
 
 
 
114
  button_gen = st.button("Generate Voice")
115
  if button_gen == True:
116
- generate_voice(input_text)
117
  st.balloons()
118
 
119
 
 
18
  torch.cuda.manual_seed(1234) if torch.cuda.is_available() else None
19
  self.flowgenerator = Glow_model(n_vocab = 70, h_c= 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads=2, layers_enc = 6).to(device)
20
  self.voicegenerator = GAN_model().to(device)
21
+ if model_variant == '๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹':
22
  name = '1038_eunsik_01'
23
  last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
24
  elif model_variant == 'KSS':
 
27
  self.flowgenerator.load_state_dict(check_point['generator'])
28
  self.flowgenerator.decoder.skip()
29
  self.flowgenerator.eval()
30
+ if model_variant == '๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹':
31
  last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
32
  elif model_variant == 'KSS':
33
  last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
 
36
  self.voicegenerator.eval()
37
  self.voicegenerator.remove_weight_norm()
38
 
39
+ def inference(self, input_textm, noise_scale = 0.667, length_scale = 1.0):
40
  filters = '([.,!?])'
41
  sentence = re.sub(re.compile(filters), '', input_text)
42
  x = text_to_sequence(sentence)
 
44
  x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)
45
 
46
  with torch.no_grad():
 
 
47
  (y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
48
  y = self.voicegenerator(y_gen_tst)
49
  audio = y.squeeze() * 32768.0
 
54
  # Model
55
  if "init_model" not in st.session_state:
56
  st.session_state.init_model = True
57
+ st.session_state.model_variant = "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹"
58
+ st.session_state.TTS = TTS("๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹")
59
 
60
  def update_model():
61
  if st.session_state.model_variant == "KSS":
62
  st.session_state.TTS = TTS("KSS")
63
+ elif st.session_state.model_variant == "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹":
64
+ st.session_state.TTS = TTS("๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹")
65
 
66
  def update_session_state(state_id, state_value):
67
  st.session_state[f"{state_id}"] = state_value
 
87
 
88
  mode = "p"
89
  st.markdown(
90
+ f"<{mode} style='text-align: left;'><small>This is a demo trained by our vocie. The voice \"KSS\" is traind by <a href= 'https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset'>KSS Dataset</a>. The voice \"๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹\" is trained from pre-trained \"KSS\". We got this deomoformat from Nix-TTS Interactive Demo</small></{mode}>",
91
  unsafe_allow_html = True
92
  )
93
 
 
98
  with col1:
99
  input_text = st.text_input(
100
  "ํ•œ๊ธ€๋กœ๋งŒ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”",
101
+ value = "๋ฐฅ์€ ๋จน๊ณ ๋‹ค๋‹ˆ๋ƒ?",
102
  )
103
  with col2:
104
+ model_variant = st.selectbox("๋ชฉ์†Œ๋ฆฌ ์„ ํƒํ•ด์ฃผ์„ธ์š”", options = ["KSS", "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹"], index = 1)
105
  if model_variant != st.session_state.model_variant:
106
  # Update variant choice
107
  update_session_state("model_variant", model_variant)
 
109
  update_model()
110
  st.snow()
111
 
112
+ noise_scale = st.slider('noise๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.', 0, 1, value = 0.66, step = 0.01)
113
+ length_scale = st.slider('์†๋„๋ฅผ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.', 0, 2, value = 1., step = 0.01)
114
  button_gen = st.button("Generate Voice")
115
  if button_gen == True:
116
+ generate_voice(input_text, noise_scale, length_scale)
117
  st.balloons()
118
 
119