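# Streamlit demo app (Hugging Face Spaces): Korean TTS with Glow-TTS + HiFi-GAN.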
import streamlit as st
import soundfile as sf
import os, re
import torch
from datautils import *
from model import Generator as Glow_model
from Hmodel import Generator as GAN_model
from Tmodel import GlowTTS as T_Glow_model
st.set_page_config(
    page_title = "์์ Team Demo",
    page_icon = "๐",
)
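# Two-stage TTS pipeline (see the page header below): Glow-TTS generates a mel
# spectrogram from the input text, and HiFi-GAN vocodes it into a 22050 Hz waveform.
# A matching checkpoint pair under ./log/ is loaded for whichever voice is selected.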
class TTS:
    def __init__(self, model_variant):
        global device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            torch.cuda.manual_seed(1234)
        # Mel-spectrogram generator: Glow-TTS (the '์ ์ทจํ ํ์ฐ' voice uses its own Glow-TTS variant).
        if model_variant != '์ ์ทจํ ํ์ฐ':
            self.flowgenerator = Glow_model(
                n_vocab = 70, h_c = 192, f_c = 768, f_c_dp = 256, out_c = 80,
                k_s = 3, k_s_dec = 5, heads = 2, layers_enc = 6).to(device)
        else:
            self.flowgenerator = T_Glow_model().to(device)
        # Vocoder: HiFi-GAN.
        self.voicegenerator = GAN_model().to(device)
        # Load the Glow-TTS checkpoint for the selected voice.
        if model_variant == '๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์์':
            last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
        elif model_variant == 'KSS':
            last_chpt1 = './log/KSS/Glow_TTS_00280641.pt'
        elif model_variant == '์ ์ทจํ ํ์ฐ':
            last_chpt1 = './log/Taeyeon/Glow_TTS_337000.pt'
        check_point = torch.load(last_chpt1, map_location = device)
        self.flowgenerator.load_state_dict(
            check_point['generator' if model_variant != '์ ์ทจํ ํ์ฐ' else 'model'])
        if model_variant != '์ ์ทจํ ํ์ฐ':
            self.flowgenerator.decoder.skip()
        self.flowgenerator.eval()
        # Load the HiFi-GAN checkpoint for the selected voice.
        if model_variant == '๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์์':
            last_chpt2 = './log/1038_eunsik_01/HiFi_GAN_00664000.pt'
        elif model_variant == 'KSS':
            last_chpt2 = './log/KSS/HiFi_GAN_00135000.pt'
        elif model_variant == '์ ์ทจํ ํ์ฐ':
            last_chpt2 = './log/Taeyeon/HiFi_GAN_400000.pt'
        check_point = torch.load(last_chpt2, map_location = device)
        self.voicegenerator.load_state_dict(check_point['gen_model'])
        self.voicegenerator.eval()
        self.voicegenerator.remove_weight_norm()
    def inference(self, input_text, noise_scale = 0.667, length_scale = 1.0):
        # Remove punctuation the text front-end does not model.
        filters = '([.,!?])' if st.session_state.model_variant != '์ ์ทจํ ํ์ฐ' else '([,])'
        sentence = re.sub(re.compile(filters), '', input_text)
        x = text_to_sequence(sentence)
        x = torch.tensor(x).unsqueeze(0).to(device).long()
        x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)
        with torch.no_grad():
            # Glow-TTS: text -> mel spectrogram (plus alignment).
            if st.session_state.model_variant != "์ ์ทจํ ํ์ฐ":
                (y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(
                    x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
            else:
                (y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(
                    x, x_length, inference = True, noise_scale = noise_scale, length_scale = length_scale)
            # HiFi-GAN: mel spectrogram -> waveform, scaled to the 16-bit PCM range.
            y = self.voicegenerator(y_gen_tst)
            audio = y.squeeze() * 32768.0
        voice = audio.cpu().numpy().astype('int16')
        return voice
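# Note: TTS.inference() reads st.session_state.model_variant, so it is meant to be called
# from inside the running app; it returns an int16 numpy waveform that st.audio can play back.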
def init_session_state():
    if "init_model" not in st.session_state:
        st.session_state.init_model = True
        st.session_state.model_variant = "์ ์ทจํ ํ์ฐ"
        st.session_state.TTS = TTS("์ ์ทจํ ํ์ฐ")

def update_model():
    if st.session_state.model_variant == "KSS":
        st.session_state.TTS = TTS("KSS")
    elif st.session_state.model_variant == "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์์":
        st.session_state.TTS = TTS("๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์์")
    elif st.session_state.model_variant == '์ ์ทจํ ํ์ฐ':
        st.session_state.TTS = TTS("์ ์ทจํ ํ์ฐ")

def update_session_state(state_id, state_value):
    st.session_state[state_id] = state_value

def centered_text(input_text, mode = "h1"):
    st.markdown(
        f"<{mode} style='text-align: center;'>{input_text}</{mode}>", unsafe_allow_html = True)
init_session_state()
centered_text("๐ ์์ Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")
mode = "p"
st.markdown(
    f"<{mode} style='text-align: left;'><small>This is a demo trained on our own voice. The \"KSS\" voice is trained on the <a href='https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset'>KSS Dataset</a>. The \"๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์์\" voice is fine-tuned from the pre-trained \"KSS\" model. The demo format is adapted from the Nix-TTS Interactive Demo.</small></{mode}>",
    unsafe_allow_html = True
)
st.write(" ")
st.write(" ")
col1, col2 = st.columns(2)
with col1:
    input_text = st.text_input(
        "ํ๊ธ๋ก๋ง ์๋ ฅํด์ฃผ์ธ์",
        value = "๋ฐฅ์ ๋จน๊ณ ๋ค๋?",
    )
with col2:
    model_variant = st.selectbox("๋ชฉ์๋ฆฌ ์ ํํด์ฃผ์ธ์", options = ["KSS", "๊ฐ๊ธฐ๊ฑธ๋ฆฐ ์์", "์ ์ทจํ ํ์ฐ"], index = 2)
    button_change = st.button("Change Voice")
if button_change:
    if model_variant != st.session_state.model_variant:
        with st.spinner('Wait for it...'):
            update_session_state("model_variant", model_variant)
            update_model()
        st.success('Done!', icon="✅")
noise_scale = st.slider('noise๋ฅผ ์ถ๊ฐํฉ๋๋ค.', 0., 2., value = 0.3, step = 0.1)
length_scale = st.slider('์๋๋ฅผ ์กฐ์ ํฉ๋๋ค.', 0., 2., value = 1., step = 0.1)
button_gen = st.button("Generate Voice")
if button_gen:
    voice = st.session_state.TTS.inference(input_text, noise_scale, length_scale)
    st.audio(voice, sample_rate = 22050)
    st.caption("Generated Voice by " + st.session_state.model_variant)
    st.balloons()