Spaces:
Runtime error
Runtime error
marigold334
commited on
Commit
•
7d58dde
1
Parent(s):
b318680
Update app.py
Browse files
app.py
CHANGED
@@ -23,6 +23,8 @@ class TTS:
|
|
23 |
last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
|
24 |
elif model_variant == 'KSS':
|
25 |
last_chpt1 = './log/KSS/Glow_TTS_00280641.pt'
|
|
|
|
|
26 |
check_point = torch.load(last_chpt1, map_location = device)
|
27 |
self.flowgenerator.load_state_dict(check_point['generator'])
|
28 |
self.flowgenerator.decoder.skip()
|
@@ -31,6 +33,8 @@ class TTS:
|
|
31 |
last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
|
32 |
elif model_variant == 'KSS':
|
33 |
last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
|
|
|
|
|
34 |
check_point = torch.load(last_chpt2, map_location = device)
|
35 |
self.voicegenerator.load_state_dict(check_point['gen_model'])
|
36 |
self.voicegenerator.eval()
|
@@ -54,14 +58,16 @@ def init_session_state():
|
|
54 |
# Model
|
55 |
if "init_model" not in st.session_state:
|
56 |
st.session_state.init_model = True
|
57 |
-
st.session_state.model_variant = "
|
58 |
-
st.session_state.TTS = TTS("
|
59 |
|
60 |
def update_model():
|
61 |
if st.session_state.model_variant == "KSS":
|
62 |
st.session_state.TTS = TTS("KSS")
|
63 |
elif st.session_state.model_variant == "감기걸린 은식":
|
64 |
st.session_state.TTS = TTS("감기걸린 은식")
|
|
|
|
|
65 |
|
66 |
def update_session_state(state_id, state_value):
|
67 |
st.session_state[f"{state_id}"] = state_value
|
@@ -89,19 +95,18 @@ col1, col2 = st.columns(2)
|
|
89 |
with col1:
|
90 |
input_text = st.text_input(
|
91 |
"한글로만 입력해주세요",
|
92 |
-
value = "밥은 먹고
|
93 |
)
|
94 |
with col2:
|
95 |
-
model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "감기걸린 은식"], index = 1)
|
96 |
|
97 |
button_change = st.button("Change Vocie")
|
98 |
if button_change == True:
|
99 |
if model_variant != st.session_state.model_variant:
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
st.snow()
|
105 |
|
106 |
noise_scale = st.slider('noise를 추가합니다.', 0., 2., value = 0.33, step = 0.01)
|
107 |
length_scale = st.slider('속도를 조절합니다.', 0., 2., value = 1., step = 0.01)
|
|
|
23 |
last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
|
24 |
elif model_variant == 'KSS':
|
25 |
last_chpt1 = './log/KSS/Glow_TTS_00280641.pt'
|
26 |
+
elif model_variant == '태연':
|
27 |
+
last_chpt1 = './log/Taeyeon/Glow_TTS_400000.pt'
|
28 |
check_point = torch.load(last_chpt1, map_location = device)
|
29 |
self.flowgenerator.load_state_dict(check_point['generator'])
|
30 |
self.flowgenerator.decoder.skip()
|
|
|
33 |
last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
|
34 |
elif model_variant == 'KSS':
|
35 |
last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
|
36 |
+
elif model_variant == '태연':
|
37 |
+
last_chpt2 = './log/Taeyeon/HiFi_GAN_337000.pt'
|
38 |
check_point = torch.load(last_chpt2, map_location = device)
|
39 |
self.voicegenerator.load_state_dict(check_point['gen_model'])
|
40 |
self.voicegenerator.eval()
|
|
|
58 |
# Model
|
59 |
if "init_model" not in st.session_state:
|
60 |
st.session_state.init_model = True
|
61 |
+
st.session_state.model_variant = "태연"
|
62 |
+
st.session_state.TTS = TTS("태연")
|
63 |
|
64 |
def update_model():
|
65 |
if st.session_state.model_variant == "KSS":
|
66 |
st.session_state.TTS = TTS("KSS")
|
67 |
elif st.session_state.model_variant == "감기걸린 은식":
|
68 |
st.session_state.TTS = TTS("감기걸린 은식")
|
69 |
+
elif st.seesion_state.model_varaiant == '태연':
|
70 |
+
st.session_state.TTS = TTS("태연")
|
71 |
|
72 |
def update_session_state(state_id, state_value):
|
73 |
st.session_state[f"{state_id}"] = state_value
|
|
|
95 |
with col1:
|
96 |
input_text = st.text_input(
|
97 |
"한글로만 입력해주세요",
|
98 |
+
value = "밥은 먹고 다녀?",
|
99 |
)
|
100 |
with col2:
|
101 |
+
model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "감기걸린 은식", "태연"], index = 1)
|
102 |
|
103 |
button_change = st.button("Change Vocie")
|
104 |
if button_change == True:
|
105 |
if model_variant != st.session_state.model_variant:
|
106 |
+
with st.spinner('Wait for it...'):
|
107 |
+
update_session_state("model_variant", model_variant)
|
108 |
+
update_model()
|
109 |
+
st.success('Done!', icon="✅")
|
|
|
110 |
|
111 |
noise_scale = st.slider('noise를 추가합니다.', 0., 2., value = 0.33, step = 0.01)
|
112 |
length_scale = st.slider('속도를 조절합니다.', 0., 2., value = 1., step = 0.01)
|