File size: 4,738 Bytes
41989ff
 
8fde97d
41989ff
 
 
 
 
dcef047
 
 
 
41989ff
8fde97d
 
dcef047
371ba49
dcef047
 
 
26dee9c
8fde97d
 
fa050b7
 
4936e8e
 
8fde97d
 
26dee9c
1258aa9
fa050b7
1258aa9
4936e8e
 
8fde97d
 
 
26dee9c
8fde97d
 
 
 
 
 
 
 
 
 
 
 
41989ff
 
 
 
 
26dee9c
 
41989ff
 
 
8fde97d
26dee9c
 
41989ff
 
 
 
 
 
 
 
8fde97d
41989ff
8fde97d
41989ff
 
b35e31e
b1f349f
dcef047
41989ff
 
 
 
 
 
 
 
 
26dee9c
41989ff
 
 
 
 
 
 
 
 
 
26dee9c
41989ff
 
26dee9c
41989ff
 
 
 
 
f523506
 
26dee9c
 
41989ff
 
26dee9c
fa050b7
41989ff
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import streamlit as st
import soundfile as sf
import os, re
import torch
from datautils import *
from model import Generator as Glow_model
from Hmodel import Generator as GAN_model

# Browser-tab title and icon for the demo page.
st.set_page_config(page_title="소신 Team Demo", page_icon="🔉")

class TTS:
    """Text-to-speech pipeline: Glow-TTS (text -> mel) followed by HiFi-GAN (mel -> audio).

    Args:
        model_variant: Name of the voice to load; must be a key of
            ``_CHECKPOINTS`` ("감기걸린 은식" or "KSS").

    Raises:
        ValueError: If *model_variant* is not a known voice.
    """

    # (Glow-TTS checkpoint, HiFi-GAN checkpoint) for each selectable voice.
    # NOTE(review): "HiFI_GAN" vs "HiFi_GAN" casing differs between voices;
    # kept as-is because it presumably matches the files on disk.
    _CHECKPOINTS = {
        '감기걸린 은식': (
            './log/1038_eunsik_01/Glow_TTS_00289602.pt',
            './log/1038_eunsik_01/HiFI_GAN_00257000.pt',
        ),
        'KSS': (
            './log/KSS/Glow_TTS_00280641.pt',
            './log/KSS/HiFi_GAN_00135000.pt',
        ),
    }

    # Punctuation removed from the input before it is turned into symbol ids.
    _PUNCTUATION_RE = re.compile(r"[.,!?]")

    def __init__(self, model_variant):
        # Validate first so an unknown voice fails fast with a clear message
        # instead of an UnboundLocalError after the models were built.
        try:
            glow_checkpoint, hifigan_checkpoint = self._CHECKPOINTS[model_variant]
        except KeyError:
            raise ValueError(
                f"Unknown model_variant {model_variant!r}; "
                f"expected one of {sorted(self._CHECKPOINTS)}"
            ) from None

        # The module-level ``device`` global is still set for backward
        # compatibility with any code that reads it; this class uses self.device.
        global device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        if torch.cuda.is_available():
            torch.cuda.manual_seed(1234)

        self.flowgenerator = Glow_model(
            n_vocab=70, h_c=192, f_c=768, f_c_dp=256, out_c=80,
            k_s=3, k_s_dec=5, heads=2, layers_enc=6,
        ).to(self.device)
        self.voicegenerator = GAN_model().to(self.device)

        checkpoint = torch.load(glow_checkpoint, map_location=self.device)
        self.flowgenerator.load_state_dict(checkpoint['generator'])
        # Model-specific decoder call; its purpose is not visible here
        # (TODO confirm — presumably prepares the flow decoder for inference).
        self.flowgenerator.decoder.skip()
        self.flowgenerator.eval()

        checkpoint = torch.load(hifigan_checkpoint, map_location=self.device)
        self.voicegenerator.load_state_dict(checkpoint['gen_model'])
        self.voicegenerator.eval()
        # Presumably folds weight normalization into the weights for inference.
        self.voicegenerator.remove_weight_norm()

    @classmethod
    def _strip_punctuation(cls, text):
        """Return *text* with the characters ``. , ! ?`` removed."""
        return cls._PUNCTUATION_RE.sub('', text)

    @staticmethod
    def _to_pcm16(waveform):
        """Convert a float waveform tensor to an int16 NumPy array.

        The tensor is squeezed, scaled by 32768 and clamped to the int16 range,
        so that samples at +1.0 (or beyond) saturate at 32767 instead of
        overflowing during the cast.
        """
        scaled = waveform.squeeze() * 32768.0
        return scaled.clamp(-32768.0, 32767.0).cpu().numpy().astype('int16')

    def inference(self, input_text, noise_scale = 0.667, length_scale = 1.0):
        """Synthesize speech for *input_text*.

        Args:
            input_text: Text to speak; ``. , ! ?`` are stripped before
                conversion with ``text_to_sequence``.
            noise_scale: Forwarded to the Glow-TTS generator (labelled "noise"
                in the UI).
            length_scale: Forwarded to the Glow-TTS generator (labelled
                "speed" in the UI).

        Returns:
            An int16 NumPy array of audio samples (the squeezed vocoder output;
            presumably 1-D mono — confirm against the vocoder).
        """
        sentence = self._strip_punctuation(input_text)
        x = torch.tensor(
            text_to_sequence(sentence), dtype=torch.long, device=self.device
        ).unsqueeze(0)
        x_length = torch.tensor([x.shape[1]], device=self.device)

        with torch.no_grad():
            (mel, *_), *_ = self.flowgenerator(
                x, x_length, gen=True,
                noise_scale=noise_scale, length_scale=length_scale,
            )
            waveform = self.voicegenerator(mel)
            return self._to_pcm16(waveform)

def init_session_state():
    """Load the default voice into Streamlit session state once per session.

    On the first run of a session this sets ``init_model``, sets
    ``model_variant`` to "감기걸린 은식" and stores a freshly built ``TTS``
    for that voice. Later reruns find ``init_model`` already present and
    leave the session state untouched.
    """
    if "init_model" in st.session_state:
        return
    default_variant = "감기걸린 은식"
    st.session_state.init_model = True
    st.session_state.model_variant = default_variant
    st.session_state.TTS = TTS(default_variant)

def update_model():
    """Rebuild ``st.session_state.TTS`` for ``st.session_state.model_variant``.

    Only "KSS" and "감기걸린 은식" trigger a reload; any other value leaves
    the current model in place.
    """
    variant = st.session_state.model_variant
    if variant in ("KSS", "감기걸린 은식"):
        st.session_state.TTS = TTS(variant)

def update_session_state(state_id, state_value):
    """Store *state_value* in Streamlit session state under the key *state_id*.

    The key is *state_id* formatted as a string.
    """
    key = format(state_id)
    st.session_state[key] = state_value
    
def centered_text(input_text, mode = "h1"):
    """Render *input_text* centered inside an HTML ``<mode>`` element (e.g. "h1", "h5").

    The text is inserted as raw HTML (``unsafe_allow_html=True``), so callers
    must pass trusted strings only.
    """
    html = f"<{mode} style='text-align: center;'>{input_text}</{mode}>"
    st.markdown(html, unsafe_allow_html=True)

def generate_voice(input_text, noise_scale = 0.667, length_scale = 1.0, sample_rate = 22050):
    """Synthesize *input_text* with the session's TTS model and play it on the page.

    Args:
        input_text: Text to synthesize.
        noise_scale: Forwarded to ``TTS.inference``; the default matches
            that method's default.
        length_scale: Forwarded to ``TTS.inference``; the default matches
            that method's default.
        sample_rate: Playback rate in Hz passed to ``st.audio``. The default
            of 22050 is the value this demo always used; it must match the
            vocoder's output rate.
    """
    # Positional pass-through: the script calls this with the slider values,
    # which previously raised TypeError because only input_text was accepted.
    voice = st.session_state.TTS.inference(input_text, noise_scale, length_scale)

    # Play audio
    st.audio(voice, sample_rate=sample_rate)
    st.caption("Generated Voice by " + st.session_state.model_variant)


# Top-level Streamlit script: build the page, react to voice selection and
# synthesize on button press. Streamlit re-runs this on every interaction.
init_session_state()

centered_text("🔉 소신 Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")

mode = "p"
st.markdown(
    f"<{mode} style='text-align: left;'><small>This is a demo trained by our voice. The voice \"KSS\" is trained by <a href= 'https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset'>KSS Dataset</a>. The voice \"감기걸린 은식\" is trained from pre-trained \"KSS\". We got this demo format from Nix-TTS Interactive Demo</small></{mode}>",
    unsafe_allow_html = True
)

st.write(" ")
st.write(" ")
col1, col2 = st.columns(2)

with col1:
    input_text = st.text_input(
        "한글로만 입력해주세요",
        value = "밥은 먹고다니냐?",
    )
with col2:
    # index=1 selects "감기걸린 은식", the same default init_session_state loads.
    model_variant = st.selectbox("목소리 선택해주세요", options = ["KSS", "감기걸린 은식"], index = 1)
    if model_variant != st.session_state.model_variant:
        # Update variant choice
        update_session_state("model_variant", model_variant)
        # Re-load model
        update_model()
        st.snow()

# Streamlit requires min/max/value/step of a slider to share one numeric
# type; integer bounds with float values/steps raise an API exception.
noise_scale = st.slider('noise를 추가합니다.', 0.0, 1.0, value = 0.66, step = 0.01)
length_scale = st.slider('속도를 조절합니다.', 0.0, 2.0, value = 1.0, step = 0.01)
button_gen = st.button("Generate Voice")
if button_gen:
    if not input_text.strip():
        # Avoid feeding an empty symbol sequence to the models.
        st.warning("텍스트를 입력해주세요.")
    else:
        generate_voice(input_text, noise_scale, length_scale)
        st.balloons()