File size: 5,343 Bytes
41989ff
 
8fde97d
41989ff
 
 
 
8ebd890
41989ff
dcef047
 
 
 
41989ff
8fde97d
 
dcef047
371ba49
dcef047
217e8d0
dcef047
764c666
8fde97d
fa050b7
 
e6cd225
ddbc9f0
4936e8e
816e79a
8346e07
8fde97d
764c666
1258aa9
fa050b7
1258aa9
e6cd225
ddbc9f0
4936e8e
 
8fde97d
 
 
87218c5
c5fff6b
 
8fde97d
 
 
 
 
eba6e07
8fde97d
 
 
 
41989ff
 
 
 
e6cd225
 
41989ff
 
 
8fde97d
764c666
 
48e5667
e6cd225
41989ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
764c666
41989ff
 
 
 
 
 
 
 
 
 
e6cd225
41989ff
 
e6cd225
b318680
 
 
41989ff
e6cd225
 
 
 
b318680
87218c5
486e21a
b318680
41989ff
 
87218c5
 
 
fa050b7
41989ff
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import streamlit as st
import soundfile as sf
import os, re
import torch
from datautils import *
from model import Generator as Glow_model
from Hmodel import Generator as GAN_model
from Tmodel import GlowTTS as T_Glow_model

# Page metadata (browser tab title and icon). Kept as the first Streamlit call of the
# script, which older Streamlit versions require.
st.set_page_config(page_title="์†Œ์‹  Team Demo", page_icon="๐Ÿ”‰")

class TTS:
    """Text-to-speech pipeline: a Glow-TTS mel generator followed by a HiFi-GAN vocoder.

    One instance holds both networks for a single voice (``model_variant``).
    Weights are loaded from the paths listed in ``_CHECKPOINTS``, and both
    networks are switched to eval mode.

    Args:
        model_variant: voice name; must be a key of ``_CHECKPOINTS``.

    Raises:
        ValueError: if ``model_variant`` is not a known voice.
    """

    # voice name -> (Glow-TTS checkpoint path, HiFi-GAN checkpoint path)
    _CHECKPOINTS = {
        '๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹': (
            './log/1038_eunsik_01/Glow_TTS_00289602.pt',
            './log/1038_eunsik_01/HiFI_GAN_00257000.pt',
        ),
        'KSS': (
            './log/KSS/Glow_TTS_00280641.pt',
            './log/KSS/HiFi_GAN_00135000.pt',
        ),
        'ํƒœ์—ฐ': (
            './log/Taeyeon/Glow_TTS_337000.pt',
            './log/Taeyeon/HiFi_GAN_400000.pt',
        ),
    }

    # The Taeyeon voice uses a different Glow-TTS implementation (Tmodel.GlowTTS).
    # That implementation differs in checkpoint key, forward() keywords and
    # punctuation filtering.
    _TAEYEON = 'ํƒœ์—ฐ'

    def __init__(self, model_variant):
        # Validate up front; an unknown name used to surface as an UnboundLocalError
        # on the checkpoint path variable.
        if model_variant not in self._CHECKPOINTS:
            raise ValueError(f"Unknown model variant: {model_variant!r}")
        glow_ckpt, vocoder_ckpt = self._CHECKPOINTS[model_variant]
        is_taeyeon = model_variant == self._TAEYEON
        # Remembered so inference() does not depend on global Streamlit session state.
        self.model_variant = model_variant

        # The module-level ``device`` global is still set for backward compatibility;
        # this class itself uses ``self.device``.
        global device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        if torch.cuda.is_available():
            torch.cuda.manual_seed(1234)

        if is_taeyeon:
            self.flowgenerator = T_Glow_model().to(device)
        else:
            self.flowgenerator = Glow_model(
                n_vocab = 70, h_c = 192, f_c = 768, f_c_dp = 256, out_c = 80,
                k_s = 3, k_s_dec = 5, heads = 2, layers_enc = 6,
            ).to(device)
        self.voicegenerator = GAN_model().to(device)

        # Mel generator: the two Glow-TTS implementations store weights under different keys.
        check_point = torch.load(glow_ckpt, map_location = device)
        self.flowgenerator.load_state_dict(check_point['model' if is_taeyeon else 'generator'])
        if not is_taeyeon:
            # Only called for the model.Generator implementation, as in the original code.
            self.flowgenerator.decoder.skip()
        self.flowgenerator.eval()

        # Vocoder.
        check_point = torch.load(vocoder_ckpt, map_location = device)
        self.voicegenerator.load_state_dict(check_point['gen_model'])
        self.voicegenerator.eval()
        self.voicegenerator.remove_weight_norm()

    @classmethod
    def _clean_text(cls, input_text, model_variant):
        """Return ``input_text`` with the punctuation the given voice should not see removed.

        The Taeyeon voice only has commas stripped. Every other voice has ``. , ! ?`` stripped.
        """
        filters = '([,])' if model_variant == cls._TAEYEON else '([.,!?])'
        return re.sub(filters, '', input_text)

    def inference(self, input_text, noise_scale = 0.667, length_scale = 1.0):
        """Synthesize ``input_text`` with this instance's voice.

        Args:
            input_text: text to speak.
            noise_scale: noise scale forwarded to the mel generator.
            length_scale: length scale forwarded to the mel generator.

        Returns:
            A NumPy int16 array of audio samples.

        Raises:
            ValueError: if the cleaned text produces an empty token sequence.
        """
        sentence = self._clean_text(input_text, self.model_variant)
        sequence = text_to_sequence(sentence)
        if len(sequence) == 0:
            raise ValueError("Input text is empty after removing punctuation.")
        x = torch.tensor(sequence).unsqueeze(0).to(self.device).long()
        x_length = torch.tensor([x.shape[1]], device = self.device)

        # The two Glow-TTS implementations select generation mode with different keywords.
        if self.model_variant == self._TAEYEON:
            mode_kwargs = {'inference': True}
        else:
            mode_kwargs = {'gen': True}

        with torch.no_grad():
            outputs = self.flowgenerator(
                x, x_length, noise_scale = noise_scale, length_scale = length_scale, **mode_kwargs)
            y_gen_tst = outputs[0][0]
            y = self.voicegenerator(y_gen_tst)
            # Scale to 16-bit PCM. The clamp stops a sample of +1.0 (-> 32768.0) from
            # wrapping around to -32768 in the int16 cast.
            audio = torch.clamp(y.squeeze() * 32768.0, -32768.0, 32767.0)
            voice = audio.cpu().numpy().astype('int16')
        return voice

def init_session_state():
    """On the first run of a session, load the default voice into Streamlit session state.

    Sets ``st.session_state.TTS``, ``model_variant`` and the ``init_model`` flag.
    Later reruns find ``init_model`` present and do nothing.
    """
    if "init_model" not in st.session_state:
        default_variant = "ํƒœ์—ฐ"
        # Build the model before marking the session initialised. If loading fails
        # (e.g. a missing checkpoint), the next rerun retries instead of finding
        # init_model set with no TTS object in session state.
        st.session_state.TTS = TTS(default_variant)
        st.session_state.model_variant = default_variant
        st.session_state.init_model = True

def update_model():
    """Reload ``st.session_state.TTS`` for the voice named in ``st.session_state.model_variant``.

    A name outside the three known voices leaves the current model untouched.
    """
    variant = st.session_state.model_variant
    if variant in ("KSS", "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹", 'ํƒœ์—ฐ'):
        st.session_state.TTS = TTS(variant)

def update_session_state(state_id, state_value):
    """Store ``state_value`` in Streamlit session state under the formatted form of ``state_id``."""
    key = format(state_id)
    st.session_state[key] = state_value
    
def centered_text(input_text, mode = "h1"):
    """Render ``input_text`` as a centered HTML element whose tag is ``mode`` (``h1`` by default).

    The text is inserted into raw HTML unescaped, so callers should pass trusted strings only.
    """
    html = "<{0} style='text-align: center;'>{1}</{0}>".format(mode, input_text)
    st.markdown(html, unsafe_allow_html = True)

# ---- Page body: header, description, input widgets, voice switching, synthesis ----
init_session_state()

centered_text("๐Ÿ”‰ ์†Œ์‹  Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")

mode = "p"
st.markdown(
    f"<{mode} style='text-align: left;'><small>This is a demo trained on our voice. The voice \"KSS\" is trained on the <a href= 'https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset'>KSS Dataset</a>. The voice \"๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹\" is trained from pre-trained \"KSS\". We got this demo format from Nix-TTS Interactive Demo</small></{mode}>",
    unsafe_allow_html = True
)

st.write(" ")
st.write(" ")
col1, col2 = st.columns(2)

# Voices offered in the UI.
VOICE_OPTIONS = ["KSS", "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹", "ํƒœ์—ฐ"]
# Default the selectbox to the voice currently loaded in the session, so the UI
# never shows one voice while another is actually used for synthesis.
if st.session_state.model_variant in VOICE_OPTIONS:
    default_voice_index = VOICE_OPTIONS.index(st.session_state.model_variant)
else:
    default_voice_index = 1

with col1:
    input_text = st.text_input(
        "ํ•œ๊ธ€๋กœ๋งŒ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”",
        value = "๋ฐฅ์€ ๋จน๊ณ  ๋‹ค๋…€?",
    )
with col2:
    model_variant = st.selectbox("๋ชฉ์†Œ๋ฆฌ ์„ ํƒํ•ด์ฃผ์„ธ์š”", options = VOICE_OPTIONS, index = default_voice_index)

# Reload the models only when the selected voice differs from the loaded one.
button_change = st.button("Change Voice")
if button_change and model_variant != st.session_state.model_variant:
    with st.spinner('Wait for it...'):
        update_session_state("model_variant", model_variant)
        update_model()
    st.success('Done!', icon="โœ…")

noise_scale = st.slider('noise๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.', 0., 2., value = 0.33, step = 0.01)
length_scale = st.slider('์†๋„๋ฅผ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.', 0., 2., value = 1., step = 0.01)

button_gen = st.button("Generate Voice")
if button_gen:
    if not input_text.strip():
        # Nothing to synthesize; avoid feeding an empty sequence to the model.
        st.warning("Please enter some text.")
    else:
        voice = st.session_state.TTS.inference(input_text, noise_scale, length_scale)
        st.audio(voice, sample_rate = 22050)
        st.caption(f"Generated Voice by {st.session_state.model_variant}")
        st.balloons()