"""Streamlit demo app: text-to-speech with Glow-TTS (mel generator) and HiFi-GAN (vocoder).

The script keeps one loaded ``TTS`` instance in ``st.session_state`` and lets
the user switch between the available voices and synthesize audio from text.
"""
import os
import re

import soundfile as sf
import streamlit as st
import torch

from datautils import *
from model import Generator as Glow_model
from Hmodel import Generator as GAN_model

# Must stay the first Streamlit command executed by the script.
st.set_page_config(
    page_title="์†Œ์‹  Team Demo",
    page_icon="๐Ÿ”‰",
)

# Voice identifiers. These exact strings are shown in the select box and
# stored in ``st.session_state.model_variant``; every comparison in this file
# goes through these constants so the spellings cannot drift apart.
VOICE_KSS = "KSS"
VOICE_EUNSIK = "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹"
VOICE_TAEYEON = "ํƒœ์—ฐ"

# voice -> (Glow-TTS checkpoint path, HiFi-GAN checkpoint path)
_CHECKPOINTS = {
    VOICE_EUNSIK: (
        "./log/1038_eunsik_01/Glow_TTS_00289602.pt",
        "./log/1038_eunsik_01/HiFI_GAN_00257000.pt",
    ),
    VOICE_KSS: (
        "./log/KSS/Glow_TTS_00280641.pt",
        "./log/KSS/HiFi_GAN_00135000.pt",
    ),
    VOICE_TAEYEON: (
        "./log/Taeyeon/Glow_TTS_400000.pt",
        "./log/Taeyeon/HiFi_GAN_337000.pt",
    ),
}

# Characters stripped from the input text before it is converted to a
# symbol sequence. The VOICE_TAEYEON model is built with a larger vocabulary
# (73 vs 70) and only has commas removed.
# NOTE(review): presumably the three extra symbols are '.', '!' and '?' -
# confirm against the symbol table in ``datautils``.
_DEFAULT_FILTER = re.compile(r"([.,!?])")
_TAEYEON_FILTER = re.compile(r"([,])")

_INT16_MIN = -32768
_INT16_MAX = 32767


class TTS:
    """A Glow-TTS + HiFi-GAN pipeline loaded for one voice.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, model_variant):
        """Build both networks and load the checkpoints for ``model_variant``.

        Args:
            model_variant: One of ``VOICE_KSS``, ``VOICE_EUNSIK`` or
                ``VOICE_TAEYEON``.

        Raises:
            ValueError: If ``model_variant`` is not a known voice.
        """
        if model_variant not in _CHECKPOINTS:
            raise ValueError(f"Unknown model variant: {model_variant!r}")

        # Remembered so that inference() can pick the matching text filter.
        self.model_variant = model_variant
        is_taeyeon = model_variant == VOICE_TAEYEON

        # Device used for both networks and for the input tensors.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            torch.cuda.manual_seed(1234)

        self.flowgenerator = Glow_model(
            n_vocab=73 if is_taeyeon else 70,
            h_c=192,
            f_c=768,
            f_c_dp=256,
            out_c=80,
            k_s=3,
            k_s_dec=5,
            heads=2,
            layers_enc=6,
        ).to(self.device)
        self.voicegenerator = GAN_model().to(self.device)

        glow_path, hifi_path = _CHECKPOINTS[model_variant]

        # The VOICE_TAEYEON Glow-TTS checkpoint stores its weights under a
        # different key than the other two.
        check_point = torch.load(glow_path, map_location=self.device)
        self.flowgenerator.load_state_dict(
            check_point["model" if is_taeyeon else "generator"]
        )
        self.flowgenerator.decoder.skip()
        self.flowgenerator.eval()

        check_point = torch.load(hifi_path, map_location=self.device)
        self.voicegenerator.load_state_dict(check_point["gen_model"])
        self.voicegenerator.eval()
        self.voicegenerator.remove_weight_norm()

    def inference(self, input_text, noise_scale=0.667, length_scale=1.0):
        """Synthesize speech for ``input_text``.

        Args:
            input_text: Text to speak; punctuation is stripped according to
                the loaded voice before conversion to a symbol sequence.
            noise_scale: Forwarded to the Glow-TTS generator.
            length_scale: Forwarded to the Glow-TTS generator.

        Returns:
            A NumPy ``int16`` array holding the generated waveform.
        """
        text_filter = (
            _TAEYEON_FILTER if self.model_variant == VOICE_TAEYEON else _DEFAULT_FILTER
        )
        sentence = text_filter.sub("", input_text)

        x = text_to_sequence(sentence)
        x = torch.tensor(x).unsqueeze(0).to(self.device).long()
        x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(self.device)

        with torch.no_grad():
            (y_gen_tst, *_), *_, (_attn_gen, *_) = self.flowgenerator(
                x,
                x_length,
                gen=True,
                noise_scale=noise_scale,
                length_scale=length_scale,
            )
            y = self.voicegenerator(y_gen_tst)

        # Scale to 16-bit PCM; clamp so a full-scale sample cannot overflow
        # when cast to int16.
        audio = (y.squeeze() * 32768.0).clamp(_INT16_MIN, _INT16_MAX)
        return audio.cpu().numpy().astype("int16")


def init_session_state():
    """Load the default voice into ``st.session_state`` on the first run."""
    if "init_model" not in st.session_state:
        # Load the model first: if loading fails, the session is not marked
        # as initialised and the next rerun retries.
        st.session_state.TTS = TTS(VOICE_TAEYEON)
        st.session_state.model_variant = VOICE_TAEYEON
        st.session_state.init_model = True


def update_model():
    """Reload ``st.session_state.TTS`` for the currently selected voice.

    Does nothing when ``st.session_state.model_variant`` is not a known voice.
    """
    variant = st.session_state.model_variant
    if variant in _CHECKPOINTS:
        st.session_state.TTS = TTS(variant)


def update_session_state(state_id, state_value):
    """Store ``state_value`` in ``st.session_state`` under ``state_id``."""
    st.session_state[f"{state_id}"] = state_value


def centered_text(input_text, mode="h1",):
    """Render ``input_text`` centered inside an HTML ``mode`` element."""
    st.markdown(
        f"<{mode} style='text-align: center;'>{input_text}</{mode}>",
        unsafe_allow_html=True,
    )


# ---------------------------------------------------------------------------
# Page layout
# ---------------------------------------------------------------------------
init_session_state()

centered_text("๐Ÿ”‰ ์†Œ์‹  Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")

mode = "p"
st.markdown(
    f"<{mode} style='text-align: left;'>This is a demo trained by our vocie. "
    "The voice \"KSS\" is traind by KSS Dataset. "
    "The voice \"๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์€์‹\" is trained from pre-trained \"KSS\". \n"
    "We got this deomoformat from Nix-TTS Interactive Demo"
    f"</{mode}>",
    unsafe_allow_html=True,
)
st.write(" ")
st.write(" ")

col1, col2 = st.columns(2)
with col1:
    input_text = st.text_input(
        "ํ•œ๊ธ€๋กœ๋งŒ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”",
        value="๋ฐฅ์€ ๋จน๊ณ  ๋‹ค๋…€?",
    )
with col2:
    model_variant = st.selectbox(
        "๋ชฉ์†Œ๋ฆฌ ์„ ํƒํ•ด์ฃผ์„ธ์š”",
        options=[VOICE_KSS, VOICE_EUNSIK, VOICE_TAEYEON],
        index=1,
    )

# Switching voices is explicit: the model is only reloaded when the button is
# pressed and the selection differs from the voice currently loaded.
button_change = st.button("Change Vocie")
if button_change and model_variant != st.session_state.model_variant:
    with st.spinner('Wait for it...'):
        update_session_state("model_variant", model_variant)
        update_model()
    st.success('Done!', icon="โœ…")

noise_scale = st.slider('noise๋ฅผ ์ถ”๊ฐ€ํ•ฉ๋‹ˆ๋‹ค.', 0., 2., value=0.33, step=0.01)
length_scale = st.slider('์†๋„๋ฅผ ์กฐ์ ˆํ•ฉ๋‹ˆ๋‹ค.', 0., 2., value=1., step=0.01)

button_gen = st.button("Generate Voice")
if button_gen:
    voice = st.session_state.TTS.inference(input_text, noise_scale, length_scale)
    st.audio(voice, sample_rate=22050)
    st.caption(f"Generated Voice by {st.session_state.model_variant}")
    st.balloons()