import os
import re  # used by TTS.inference below
import uuid

import streamlit as st
import soundfile as sf
import torch
from scipy.io.wavfile import write  # used by TTS.inference below

from datautils import *
from model import Generator as Glow_model
from utils import scan_checkpoint, plot_mel, plot_alignment
from Hmodel import Generator as GAN_model

MAX_WAV_VALUE = 32768.0  # full scale of 16-bit PCM audio (2 ** 15)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.cuda.manual_seed(1234)
name = '1038_eunsik_01'  # speaker / checkpoint directory name

# Nix
from nix.models.TTS import NixTTSInference
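# NixTTSInference (from the Nix-TTS package) is used below through two calls:
# TTS.tokenize(text) -> (token ids, token lengths, phonemes) and
# TTS.vocalize(tokens, lengths) -> a batched waveform array sampled at 22,050 Hz.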

def init_session_state():
    # Model
    if "init_model" not in st.session_state:
        st.session_state.init_model = True
        st.session_state.model_variant = "KSS"
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")

def update_model():
    if st.session_state.model_variant == "KSS":
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-v0.1")
    elif st.session_state.model_variant == "은식":
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")
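        # Both checkpoints come from the Nix-TTS release assets; "sdp" refers to
        # the stochastic-duration-predictor variant of the model.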

def update_session_state(state_id, state_value):
    st.session_state[f"{state_id}"] = state_value
    
def centered_text(input_text, mode = "h1"):
    st.markdown(
        f"<{mode} style='text-align: center;'>{input_text}</{mode}>", unsafe_allow_html = True)

def generate_voice(input_text):
    # TTS inference
    c, c_length, phoneme = st.session_state.TTS.tokenize(input_text)
    voice = st.session_state.TTS.vocalize(c, c_length)

    # Save the audio to a temporary file (Streamlit can't play a numpy array
    # directly); a uuid-based name avoids illegal filename characters from the text
    os.makedirs("cache_sound", exist_ok = True)
    cache_path = f"cache_sound/{uuid.uuid4().hex}.wav"
    sf.write(cache_path, voice[0, 0], 22050)

    # Play the audio, then delete the temporary file
    st.audio(cache_path, format = "audio/wav")
    os.remove(cache_path)
    st.caption("Generated Voice")
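
def generate_voice_in_memory(input_text):
    # Hedged alternative sketch (not wired to the UI below): st.audio also
    # accepts raw bytes, so the round-trip through cache_sound/ can be avoided
    # by writing the WAV into an in-memory buffer. Assumes voice[0, 0] is a
    # 1-D float waveform at 22,050 Hz, as in generate_voice above.
    import io
    c, c_length, phoneme = st.session_state.TTS.tokenize(input_text)
    voice = st.session_state.TTS.vocalize(c, c_length)
    buffer = io.BytesIO()
    sf.write(buffer, voice[0, 0], 22050, format = "WAV")
    st.audio(buffer.getvalue(), format = "audio/wav")
    st.caption("Generated Voice")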
    
st.set_page_config(
    page_title = "μ†Œμ‹  Team Demo",
    page_icon = "πŸ”‰",
)

init_session_state()

centered_text("πŸ”‰ μ†Œμ‹  Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")

mode = "p"
st.markdown(
    f"<{mode} style='text-align: left;'><small>This is a demo trained by our vocie.&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; The voice \"KSS\" is traind 3 times \"은식\" is finetuned from \"KSS\" for 3 times &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; We got this deomoformat from Nix-TTS Interactive Demo</small></{mode}>",
    unsafe_allow_html = True
)

st.write(" ")
st.write(" ")
col1, col2 = st.columns(2)

with col1:
    input_text = st.text_input(
        "ν•œκΈ€λ‘œλ§Œ μž…λ ₯ν•΄μ£Όμ„Έμš”",  # "Please type in Hangul (Korean) only"
        value = "λ”₯λŸ¬λ‹μ€ 정말 μž¬λ°Œμ–΄!",  # "Deep learning is really fun!"
    )
with col2:
    # Label: "Please choose a voice"; "은식" ("Eunsik") is the fine-tuned speaker
    model_variant = st.selectbox("λͺ©μ†Œλ¦¬ μ„ νƒν•΄μ£Όμ„Έμš”", options = ["KSS", "은식"], index = 1)
    if model_variant != st.session_state.model_variant:
        # Update variant choice
        update_session_state("model_variant", model_variant)
        # Re-load model
        update_model()

button_gen = st.button("Generate Voice")
if button_gen:
    generate_voice(input_text)
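# Note: st.button returns True only during the single script rerun triggered
# by the click, so generate_voice runs once per press.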


class TTS:
    def __init__(self, model_variant):
        # Glow-TTS mel generator and HiFi-GAN vocoder
        self.flowgenerator = Glow_model(n_vocab = 70, h_c = 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads = 2, layers_enc = 6).to(device)
        self.voicegenerator = GAN_model().to(device)

        # Only the fine-tuned "은식" checkpoints are configured in this demo
        if model_variant == '은식':
            last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
            last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
        else:
            raise ValueError(f'no checkpoints configured for variant {model_variant!r}')

        check_point = torch.load(last_chpt1, map_location = device)
        self.flowgenerator.load_state_dict(check_point['generator'])
        self.flowgenerator.decoder.skip()
        self.flowgenerator.eval()

        check_point = torch.load(last_chpt2, map_location = device)
        self.voicegenerator.load_state_dict(check_point['gen_model'])
        self.voicegenerator.eval()
        self.voicegenerator.remove_weight_norm()
    
    def inference(self, input_text, out_dir = '.'):
        # Strip punctuation, then convert the sentence to a token id sequence
        sentence = re.sub('([.,!?])', '', input_text)
        x = torch.tensor(text_to_sequence(sentence)).unsqueeze(0).to(device).long()
        x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)

        with torch.no_grad():
            noise_scale = .667
            length_scale = 1.0
            (y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
            y = self.voicegenerator(y_gen_tst)
            audio = y.squeeze() * MAX_WAV_VALUE
            audio = audio.cpu().numpy().astype('int16')

        output_file = os.path.join(out_dir, 'gen_' + input_text[:3] + '.wav')
        write(output_file, 22050, audio)
        print(f'{input_text} is stored in {out_dir}')

        # Return the artifacts needed for playback and for the plots below
        return audio, y_gen_tst, attn_gen, x, output_file

def offline_demo(input_text = 'λ”₯λŸ¬λ‹μ€ 정말 μž¬λ°Œμ–΄!'):
    # Notebook-only sanity check for the Glow-TTS / HiFi-GAN path (run outside
    # Streamlit); assumes plot_mel and plot_alignment return matplotlib figures.
    audio, y_gen_tst, attn_gen, x, output_file = TTS('은식').inference(input_text)
    fig1 = plot_mel(y_gen_tst[0].data.cpu().numpy())
    fig2 = plot_alignment(attn_gen[0, 0].data.cpu().numpy(), sequence_to_text(x[0].data.cpu().numpy()))
    import IPython.display as ipd  # notebook-only dependency, imported lazily
    ipd.display(fig1, fig2)
    return ipd.Audio(filename = output_file)
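
# Usage sketch for the offline path (the module name "app" is hypothetical;
# it assumes this file is saved as app.py and imported from a notebook):
#   from app import offline_demo
#   offline_demo()  # synthesizes the default sentence and shows mel/alignment/audio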