Spaces:
Runtime error
Runtime error
marigold334
committed on
Commit
•
e49dd71
1
Parent(s):
bb04426
Update app.py
Browse files
app.py
CHANGED
@@ -23,17 +23,17 @@ class TTS:
|
|
23 |
last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
|
24 |
elif model_variant == 'KSS':
|
25 |
last_chpt1 = './log/KSS/Glow_TTS_00280641.pt'
|
26 |
-
elif model_variant == '태연':
|
27 |
last_chpt1 = './log/Taeyeon/Glow_TTS_337000.pt'
|
28 |
check_point = torch.load(last_chpt1, map_location = device)
|
29 |
-
self.flowgenerator.load_state_dict(check_point['generator' if model_variant != '태연' else 'model'])
|
30 |
-
self.flowgenerator.decoder.skip() if model_variant != '태연' else None
|
31 |
self.flowgenerator.eval()
|
32 |
if model_variant == '감기걸린 은식':
|
33 |
-
last_chpt2 = './log/1038_eunsik_01/
|
34 |
elif model_variant == 'KSS':
|
35 |
last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
|
36 |
-
elif model_variant == '태연':
|
37 |
last_chpt2 = './log/Taeyeon/HiFi_GAN_400000.pt'
|
38 |
check_point = torch.load(last_chpt2, map_location = device)
|
39 |
self.voicegenerator.load_state_dict(check_point['gen_model'])
|
@@ -41,7 +41,7 @@ class TTS:
|
|
41 |
self.voicegenerator.remove_weight_norm()
|
42 |
|
43 |
def inference(self, input_text, noise_scale = 0.667, length_scale = 1.0):
|
44 |
-
filters = '([.,!?])' if st.session_state != '태연' else '([,])'
|
45 |
sentence = re.sub(re.compile(filters), '', input_text)
|
46 |
x = text_to_sequence(sentence)
|
47 |
x = torch.autograd.Variable(torch.tensor(x).unsqueeze(0)).to(device).long()
|
@@ -57,16 +57,16 @@ class TTS:
|
|
57 |
def init_session_state():
|
58 |
if "init_model" not in st.session_state:
|
59 |
st.session_state.init_model = True
|
60 |
-
st.session_state.model_variant = "태연"
|
61 |
-
st.session_state.TTS = TTS("태연")
|
62 |
|
63 |
def update_model():
|
64 |
if st.session_state.model_variant == "KSS":
|
65 |
st.session_state.TTS = TTS("KSS")
|
66 |
elif st.session_state.model_variant == "감기걸린 은식":
|
67 |
st.session_state.TTS = TTS("감기걸린 은식")
|
68 |
-
elif st.session_state.model_variant == '태연':
|
69 |
-
st.session_state.TTS = TTS("태연")
|
70 |
|
71 |
def update_session_state(state_id, state_value):
|
72 |
st.session_state[f"{state_id}"] = state_value
|
@@ -97,7 +97,7 @@ with col1:
|
|
97 |
value = "밥은 먹고 다녀?",
|
98 |
)
|
99 |
with col2:
|
100 |
-
model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "감기걸린 은식", "태연"], index = 1)
|
101 |
|
102 |
button_change = st.button("Change Vocie")
|
103 |
if button_change == True:
|
@@ -107,8 +107,8 @@ if button_change == True:
|
|
107 |
update_model()
|
108 |
st.success('Done!', icon="✅")
|
109 |
|
110 |
-
noise_scale = st.slider('noise๋ฅผ ์ถ๊ฐํฉ๋๋ค.', 0., 2., value = 0.
|
111 |
-
length_scale = st.slider('속도를 조절합니다.', 0., 2., value = 1., step = 0.
|
112 |
|
113 |
button_gen = st.button("Generate Voice")
|
114 |
if button_gen == True:
|
|
|
23 |
last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
|
24 |
elif model_variant == 'KSS':
|
25 |
last_chpt1 = './log/KSS/Glow_TTS_00280641.pt'
|
26 |
+
elif model_variant == '술취한 태연':
|
27 |
last_chpt1 = './log/Taeyeon/Glow_TTS_337000.pt'
|
28 |
check_point = torch.load(last_chpt1, map_location = device)
|
29 |
+
self.flowgenerator.load_state_dict(check_point['generator' if model_variant != '술취한 태연' else 'model'])
|
30 |
+
self.flowgenerator.decoder.skip() if model_variant != '술취한 태연' else None
|
31 |
self.flowgenerator.eval()
|
32 |
if model_variant == '감기걸린 은식':
|
33 |
+
last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00664000.pt'
|
34 |
elif model_variant == 'KSS':
|
35 |
last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
|
36 |
+
elif model_variant == '술취한 태연':
|
37 |
last_chpt2 = './log/Taeyeon/HiFi_GAN_400000.pt'
|
38 |
check_point = torch.load(last_chpt2, map_location = device)
|
39 |
self.voicegenerator.load_state_dict(check_point['gen_model'])
|
|
|
41 |
self.voicegenerator.remove_weight_norm()
|
42 |
|
43 |
def inference(self, input_text, noise_scale = 0.667, length_scale = 1.0):
|
44 |
+
filters = '([.,!?])' if st.session_state != '술취한 태연' else '([,])'
|
45 |
sentence = re.sub(re.compile(filters), '', input_text)
|
46 |
x = text_to_sequence(sentence)
|
47 |
x = torch.autograd.Variable(torch.tensor(x).unsqueeze(0)).to(device).long()
|
|
|
57 |
def init_session_state():
|
58 |
if "init_model" not in st.session_state:
|
59 |
st.session_state.init_model = True
|
60 |
+
st.session_state.model_variant = "술취한 태연"
|
61 |
+
st.session_state.TTS = TTS("술취한 태연")
|
62 |
|
63 |
def update_model():
|
64 |
if st.session_state.model_variant == "KSS":
|
65 |
st.session_state.TTS = TTS("KSS")
|
66 |
elif st.session_state.model_variant == "감기걸린 은식":
|
67 |
st.session_state.TTS = TTS("감기걸린 은식")
|
68 |
+
elif st.session_state.model_variant == '술취한 태연':
|
69 |
+
st.session_state.TTS = TTS("술취한 태연")
|
70 |
|
71 |
def update_session_state(state_id, state_value):
|
72 |
st.session_state[f"{state_id}"] = state_value
|
|
|
97 |
value = "밥은 먹고 다녀?",
|
98 |
)
|
99 |
with col2:
|
100 |
+
model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "감기걸린 은식", "술취한 태연"], index = 1)
|
101 |
|
102 |
button_change = st.button("Change Vocie")
|
103 |
if button_change == True:
|
|
|
107 |
update_model()
|
108 |
st.success('Done!', icon="✅")
|
109 |
|
110 |
+
noise_scale = st.slider('noise๋ฅผ ์ถ๊ฐํฉ๋๋ค.', 0., 2., value = 0.3, step = 0.1)
|
111 |
+
length_scale = st.slider('속도를 조절합니다.', 0., 2., value = 1., step = 0.1)
|
112 |
|
113 |
button_gen = st.button("Generate Voice")
|
114 |
if button_gen == True:
|