Spaces:
Runtime error
Runtime error
marigold334
committed on
Commit
•
26dee9c
1
Parent(s):
f523506
Update app.py
Browse files
app.py
CHANGED
@@ -18,7 +18,7 @@ class TTS:
|
|
18 |
torch.cuda.manual_seed(1234) if torch.cuda.is_available() else None
|
19 |
self.flowgenerator = Glow_model(n_vocab = 70, h_c= 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads=2, layers_enc = 6).to(device)
|
20 |
self.voicegenerator = GAN_model().to(device)
|
21 |
-
if model_variant == '은식':
|
22 |
name = '1038_eunsik_01'
|
23 |
last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
|
24 |
elif model_variant == 'KSS':
|
@@ -27,7 +27,7 @@ class TTS:
|
|
27 |
self.flowgenerator.load_state_dict(check_point['generator'])
|
28 |
self.flowgenerator.decoder.skip()
|
29 |
self.flowgenerator.eval()
|
30 |
-
if model_variant == '은식':
|
31 |
last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
|
32 |
elif model_variant == 'KSS':
|
33 |
last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
|
@@ -36,7 +36,7 @@ class TTS:
|
|
36 |
self.voicegenerator.eval()
|
37 |
self.voicegenerator.remove_weight_norm()
|
38 |
|
39 |
-
def inference(self,
|
40 |
filters = '([.,!?])'
|
41 |
sentence = re.sub(re.compile(filters), '', input_text)
|
42 |
x = text_to_sequence(sentence)
|
@@ -44,8 +44,6 @@ class TTS:
|
|
44 |
x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)
|
45 |
|
46 |
with torch.no_grad():
|
47 |
-
noise_scale = .667
|
48 |
-
length_scale = 1.0
|
49 |
(y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
|
50 |
y = self.voicegenerator(y_gen_tst)
|
51 |
audio = y.squeeze() * 32768.0
|
@@ -56,14 +54,14 @@ def init_session_state():
|
|
56 |
# Model
|
57 |
if "init_model" not in st.session_state:
|
58 |
st.session_state.init_model = True
|
59 |
-
st.session_state.model_variant = "은식"
|
60 |
-
st.session_state.TTS = TTS("은식")
|
61 |
|
62 |
def update_model():
|
63 |
if st.session_state.model_variant == "KSS":
|
64 |
st.session_state.TTS = TTS("KSS")
|
65 |
-
elif st.session_state.model_variant == "은식":
|
66 |
-
st.session_state.TTS = TTS("은식")
|
67 |
|
68 |
def update_session_state(state_id, state_value):
|
69 |
st.session_state[f"{state_id}"] = state_value
|
@@ -89,7 +87,7 @@ st.write(" ")
|
|
89 |
|
90 |
mode = "p"
|
91 |
st.markdown(
|
92 |
-
f"<{mode} style='text-align: left;'><small>This is a demo trained by our vocie. The voice \"KSS\" is traind by KSS Dataset
|
93 |
unsafe_allow_html = True
|
94 |
)
|
95 |
|
@@ -100,10 +98,10 @@ col1, col2 = st.columns(2)
|
|
100 |
with col1:
|
101 |
input_text = st.text_input(
|
102 |
"한글로만 입력해주세요",
|
103 |
-
value = "
|
104 |
)
|
105 |
with col2:
|
106 |
-
model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "은식"], index = 1)
|
107 |
if model_variant != st.session_state.model_variant:
|
108 |
# Update variant choice
|
109 |
update_session_state("model_variant", model_variant)
|
@@ -111,9 +109,11 @@ with col2:
|
|
111 |
update_model()
|
112 |
st.snow()
|
113 |
|
|
|
|
|
114 |
button_gen = st.button("Generate Voice")
|
115 |
if button_gen == True:
|
116 |
-
generate_voice(input_text)
|
117 |
st.balloons()
|
118 |
|
119 |
|
|
|
18 |
torch.cuda.manual_seed(1234) if torch.cuda.is_available() else None
|
19 |
self.flowgenerator = Glow_model(n_vocab = 70, h_c= 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads=2, layers_enc = 6).to(device)
|
20 |
self.voicegenerator = GAN_model().to(device)
|
21 |
+
if model_variant == '감기걸린 은식':
|
22 |
name = '1038_eunsik_01'
|
23 |
last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
|
24 |
elif model_variant == 'KSS':
|
|
|
27 |
self.flowgenerator.load_state_dict(check_point['generator'])
|
28 |
self.flowgenerator.decoder.skip()
|
29 |
self.flowgenerator.eval()
|
30 |
+
if model_variant == '감기걸린 은식':
|
31 |
last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
|
32 |
elif model_variant == 'KSS':
|
33 |
last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
|
|
|
36 |
self.voicegenerator.eval()
|
37 |
self.voicegenerator.remove_weight_norm()
|
38 |
|
39 |
+
def inference(self, input_textm, noise_scale = 0.667, length_scale = 1.0):
|
40 |
filters = '([.,!?])'
|
41 |
sentence = re.sub(re.compile(filters), '', input_text)
|
42 |
x = text_to_sequence(sentence)
|
|
|
44 |
x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)
|
45 |
|
46 |
with torch.no_grad():
|
|
|
|
|
47 |
(y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
|
48 |
y = self.voicegenerator(y_gen_tst)
|
49 |
audio = y.squeeze() * 32768.0
|
|
|
54 |
# Model
|
55 |
if "init_model" not in st.session_state:
|
56 |
st.session_state.init_model = True
|
57 |
+
st.session_state.model_variant = "감기걸린 은식"
|
58 |
+
st.session_state.TTS = TTS("감기걸린 은식")
|
59 |
|
60 |
def update_model():
|
61 |
if st.session_state.model_variant == "KSS":
|
62 |
st.session_state.TTS = TTS("KSS")
|
63 |
+
elif st.session_state.model_variant == "감기걸린 은식":
|
64 |
+
st.session_state.TTS = TTS("감기걸린 은식")
|
65 |
|
66 |
def update_session_state(state_id, state_value):
|
67 |
st.session_state[f"{state_id}"] = state_value
|
|
|
87 |
|
88 |
mode = "p"
|
89 |
st.markdown(
|
90 |
+
f"<{mode} style='text-align: left;'><small>This is a demo trained by our vocie. The voice \"KSS\" is traind by <a href= 'https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset'>KSS Dataset</a>. The voice \"감기걸린 은식\" is trained from pre-trained \"KSS\". We got this deomoformat from Nix-TTS Interactive Demo</small></{mode}>",
|
91 |
unsafe_allow_html = True
|
92 |
)
|
93 |
|
|
|
98 |
with col1:
|
99 |
input_text = st.text_input(
|
100 |
"한글로만 입력해주세요",
|
101 |
+
value = "밥은 먹고 다니냐?",
|
102 |
)
|
103 |
with col2:
|
104 |
+
model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "감기걸린 은식"], index = 1)
|
105 |
if model_variant != st.session_state.model_variant:
|
106 |
# Update variant choice
|
107 |
update_session_state("model_variant", model_variant)
|
|
|
109 |
update_model()
|
110 |
st.snow()
|
111 |
|
112 |
+
noise_scale = st.slider('noise๋ฅผ ์ถ๊ฐํฉ๋๋ค.', 0, 1, value = 0.66, step = 0.01)
|
113 |
+
length_scale = st.slider('속도를 조절합니다.', 0, 2, value = 1., step = 0.01)
|
114 |
button_gen = st.button("Generate Voice")
|
115 |
if button_gen == True:
|
116 |
+
generate_voice(input_text, noise_scale, length_scale)
|
117 |
st.balloons()
|
118 |
|
119 |
|