Glow-HiFi-TTS / app.py
marigold334's picture
Update app.py (#37)
d8c4b79
raw
history blame
5.44 kB
import streamlit as st
import soundfile as sf
import os, re
import torch
from datautils import *
from model import Generator as Glow_model
from Hmodel import Generator as GAN_model
from Tmodel import GlowTTS as T_Glow_model
# Browser-tab title and icon for the demo page.
st.set_page_config(page_title="์†Œ์‹  Team Demo", page_icon="๐Ÿ”‰")
class TTS:
    """Two-stage speech synthesiser: a Glow-TTS mel generator feeding a HiFi-GAN vocoder.

    Every supported voice has its own pair of checkpoints (see ``CHECKPOINTS``).
    The voice stored under ``./log/Taeyeon`` is handled differently from the
    other two: it is built from ``T_Glow_model``, its weights live under the
    ``'model'`` checkpoint key, it is called with ``inference=True`` and only
    commas are stripped from its input text.  All of those decisions are keyed
    on the single ``TAEYEON_VOICE`` label so they cannot drift apart.
    """

    # Label of the voice that uses the T_Glow_model architecture.
    TAEYEON_VOICE = '์ˆ ์ทจํ•œ ํƒœ์—ฐ'

    # voice label -> (Glow-TTS checkpoint path, HiFi-GAN checkpoint path)
    CHECKPOINTS = {
        '๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹': (
            './log/1038_eunsik_01/Glow_TTS_00289602.pt',
            './log/1038_eunsik_01/HiFI_GAN_00664000.pt',
        ),
        'KSS': (
            './log/KSS/Glow_TTS_00280641.pt',
            './log/KSS/HiFi_GAN_00135000.pt',
        ),
        TAEYEON_VOICE: (
            './log/Taeyeon/Glow_TTS_337000.pt',
            './log/Taeyeon/HiFi_GAN_400000.pt',
        ),
    }

    def __init__(self, model_variant):
        """Load both networks for ``model_variant`` and put them in eval mode.

        Args:
            model_variant: One of the voice labels in ``CHECKPOINTS``.

        Raises:
            ValueError: If ``model_variant`` is not a known voice.  Checked
                before any model or checkpoint is touched.
        """
        if model_variant not in self.CHECKPOINTS:
            raise ValueError(
                f"Unknown voice {model_variant!r}; "
                f"expected one of {sorted(self.CHECKPOINTS)}"
            )

        # The module-level ``device`` is still published for any code that
        # relies on it; the instance keeps its own reference as well.
        global device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            torch.cuda.manual_seed(1234)
        self.device = device
        self.model_variant = model_variant

        is_taeyeon = model_variant == self.TAEYEON_VOICE
        glow_path, hifi_path = self.CHECKPOINTS[model_variant]

        # --- mel generator -------------------------------------------------
        if is_taeyeon:
            self.flowgenerator = T_Glow_model().to(device)
        else:
            self.flowgenerator = Glow_model(
                n_vocab=70, h_c=192, f_c=768, f_c_dp=256, out_c=80,
                k_s=3, k_s_dec=5, heads=2, layers_enc=6,
            ).to(device)
        self.voicegenerator = GAN_model().to(device)

        check_point = torch.load(glow_path, map_location=device)
        # The two architectures were saved under different checkpoint keys.
        self.flowgenerator.load_state_dict(
            check_point['model' if is_taeyeon else 'generator']
        )
        if not is_taeyeon:
            # Only called for Glow_model; presumably prepares its decoder for
            # inference -- TODO confirm against model.py.
            self.flowgenerator.decoder.skip()
        self.flowgenerator.eval()

        # --- vocoder -------------------------------------------------------
        check_point = torch.load(hifi_path, map_location=device)
        self.voicegenerator.load_state_dict(check_point['gen_model'])
        self.voicegenerator.eval()
        self.voicegenerator.remove_weight_norm()

    def _strip_punctuation(self, text):
        """Return ``text`` without the punctuation this voice does not accept.

        The Taeyeon voice drops only commas; the other voices drop ``. , ! ?``.
        """
        if self.model_variant == self.TAEYEON_VOICE:
            pattern = '([,])'
        else:
            pattern = '([.,!?])'
        return re.sub(pattern, '', text)

    def inference(self, input_text, noise_scale = 0.667, length_scale = 1.0):
        """Synthesise ``input_text`` and return the waveform as int16 samples.

        Args:
            input_text: Text to speak; punctuation is filtered per voice.
            noise_scale: Forwarded unchanged to the mel generator.
            length_scale: Forwarded unchanged to the mel generator.

        Returns:
            A NumPy ``int16`` array produced from the vocoder output.
        """
        sentence = self._strip_punctuation(input_text)
        x = text_to_sequence(sentence)
        x = torch.tensor(x).unsqueeze(0).to(self.device).long()
        x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(self.device)

        # The two mel-generator classes use different keywords to switch
        # their forward pass into generation mode.
        if self.model_variant == self.TAEYEON_VOICE:
            mode_kwargs = {'inference': True}
        else:
            mode_kwargs = {'gen': True}

        with torch.no_grad():
            # Only the first element of the first returned group is needed.
            (y_gen_tst, *_), *_ = self.flowgenerator(
                x, x_length,
                noise_scale=noise_scale, length_scale=length_scale,
                **mode_kwargs,
            )
            y = self.voicegenerator(y_gen_tst)

        # Scale to the 16-bit range and clamp first: a sample of exactly +1.0
        # would become 32768, which does not fit in int16.
        audio = torch.clamp(y.squeeze() * 32768.0, -32768.0, 32767.0)
        voice = audio.cpu().numpy().astype('int16')
        return voice
def init_session_state():
    """Seed Streamlit session state the first time a session runs the script.

    Does nothing when ``init_model`` is already present; otherwise it records
    the default voice and builds the TTS pipeline for it.
    """
    if "init_model" in st.session_state:
        return

    default_voice = "์ˆ ์ทจํ•œ ํƒœ์—ฐ"
    st.session_state.init_model = True
    st.session_state.model_variant = default_voice
    st.session_state.TTS = TTS(default_voice)
def update_model():
    """Rebuild the session's TTS pipeline for the voice named in session state.

    Leaves the existing pipeline untouched when the stored voice is not one of
    the three known labels.
    """
    selected = st.session_state.model_variant
    known_voices = ("KSS", "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹", '์ˆ ์ทจํ•œ ํƒœ์—ฐ')
    if selected in known_voices:
        st.session_state.TTS = TTS(selected)
def update_session_state(state_id, state_value):
    """Store ``state_value`` in session state under the formatted ``state_id``."""
    # format(x) is what an f-string placeholder with no format spec evaluates to.
    key = format(state_id)
    st.session_state[key] = state_value
def centered_text(input_text, mode = "h1",):
    """Render ``input_text`` centred on the page inside an HTML ``mode`` tag.

    Args:
        input_text: Text placed between the opening and closing tag as-is.
        mode: HTML tag name used for both the opening and closing tag.
    """
    markup = f"<{mode} style='text-align: center;'>{input_text}</{mode}>"
    # Raw HTML is required for the inline style to take effect.
    st.markdown(markup, unsafe_allow_html=True)
# Make sure the session has a loaded TTS pipeline, then draw the page header.
init_session_state()
centered_text("๐Ÿ”‰ ์†Œ์‹  Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")

# Left-aligned introduction paragraph; raw HTML is used for <small> and the link.
mode = "p"
intro_html = f"<{mode} style='text-align: left;'><small>This is a demo trained by our vocie. The voice \"KSS\" is traind by <a href= 'https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset'>KSS Dataset</a>. The voice \"๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹\" is trained from pre-trained \"KSS\". We got this deomoformat from Nix-TTS Interactive Demo</small></{mode}>"
st.markdown(intro_html, unsafe_allow_html=True)

# Two spacer rows before the input widgets.
st.write(" ")
st.write(" ")
# Side-by-side inputs: text to speak on the left, voice picker on the right.
col1, col2 = st.columns(2)

with col1:
    input_text = st.text_input("ํ•œ๊ธ€๋กœ๋งŒ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”", value="๋ฐฅ์€ ๋จน๊ณ  ๋‹ค๋…€?")

with col2:
    # NOTE(review): index=1 preselects the second option, while
    # init_session_state() loads the third one, so the picker and the loaded
    # voice disagree until the change button is pressed -- confirm intended.
    model_variant = st.selectbox(
        "๋ชฉ์†Œ๋ฆฌ ์„ ํƒํ•ด์ฃผ์„ธ์š”",
        options=["KSS", "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹", "์ˆ ์ทจํ•œ ํƒœ์—ฐ"],
        index=1,
    )
# Reload the models only on an explicit click, and only when the picked voice
# differs from the one already held in session state.
button_change = st.button("Change Vocie")
if button_change and model_variant != st.session_state.model_variant:
    with st.spinner('Wait for it...'):
        update_session_state("model_variant", model_variant)
        update_model()
    st.success('Done!', icon="โœ…")
# Synthesis controls; both values are passed straight to TTS.inference().
noise_scale = st.slider('noise๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.', 0., 2., value = 0.3, step = 0.1)
length_scale = st.slider('์†๋„๋ฅผ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.', 0., 2., value = 1., step = 0.1)

button_gen = st.button("Generate Voice")
if button_gen:
    # Uses the pipeline currently held in session state, which may differ from
    # the voice shown in the picker if the change button was not pressed.
    voice = st.session_state.TTS.inference(input_text, noise_scale, length_scale)
    st.audio(voice, sample_rate=22050)
    # Trailing space added: the label and the voice name used to run together.
    st.caption("Generated Voice by " + st.session_state.model_variant)
    st.balloons()