Spaces:
Runtime error
Runtime error
File size: 4,682 Bytes
41989ff 8fde97d 41989ff dcef047 41989ff 8fde97d dcef047 371ba49 dcef047 764c666 8fde97d fa050b7 4936e8e 8fde97d 764c666 1258aa9 fa050b7 1258aa9 4936e8e 8fde97d 87218c5 8fde97d 41989ff 764c666 41989ff 8fde97d 764c666 41989ff 764c666 41989ff b318680 41989ff 764c666 b318680 41989ff b318680 87218c5 486e21a b318680 41989ff 87218c5 fa050b7 41989ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import streamlit as st
import soundfile as sf
import os, re
import torch
from datautils import *
from model import Generator as Glow_model
from Hmodel import Generator as GAN_model
# Browser-tab title and icon for the demo page. Streamlit expects this call to
# run before any other st.* command in the script.
st.set_page_config(page_title="소신 Team Demo", page_icon="🔉")
class TTS:
    """Text-to-speech pipeline: Glow-TTS mel generator followed by a HiFi-GAN vocoder.

    Both networks are constructed on instantiation and loaded from the local
    checkpoint pair registered for the requested voice in ``_CHECKPOINTS``.
    """

    # Voice name -> (Glow-TTS checkpoint path, HiFi-GAN checkpoint path).
    # NOTE(review): the 은식 vocoder file is spelled "HiFI_GAN" while the KSS one is
    # "HiFi_GAN"; on a case-sensitive filesystem confirm which spelling exists on disk.
    _CHECKPOINTS = {
        '감기걸린 은식': (
            './log/1038_eunsik_01/Glow_TTS_00289602.pt',
            './log/1038_eunsik_01/HiFI_GAN_00257000.pt',
        ),
        'KSS': (
            './log/KSS/Glow_TTS_00280641.pt',
            './log/KSS/HiFi_GAN_00135000.pt',
        ),
    }

    # Punctuation removed from the input before tokenisation; compiled once per class.
    _PUNCTUATION = re.compile('([.,!?])')

    def __init__(self, model_variant):
        """Build both generators and load the checkpoints for *model_variant*.

        Args:
            model_variant: voice name; must be a key of ``_CHECKPOINTS``.

        Raises:
            ValueError: if *model_variant* is not a known voice. This is checked
                before any network is built or checkpoint is read.
        """
        try:
            glow_ckpt, hifigan_ckpt = self._CHECKPOINTS[model_variant]
        except KeyError:
            raise ValueError(
                f"Unknown model_variant {model_variant!r}; "
                f"expected one of {sorted(self._CHECKPOINTS)}"
            ) from None

        # The module-level ``device`` global is still published for backward
        # compatibility; this class itself only reads ``self.device``.
        global device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        if torch.cuda.is_available():
            torch.cuda.manual_seed(1234)

        self.flowgenerator = Glow_model(
            n_vocab=70, h_c=192, f_c=768, f_c_dp=256, out_c=80,
            k_s=3, k_s_dec=5, heads=2, layers_enc=6,
        ).to(device)
        self.voicegenerator = GAN_model().to(device)

        check_point = torch.load(glow_ckpt, map_location=device)
        self.flowgenerator.load_state_dict(check_point['generator'])
        # NOTE(review): decoder.skip() is defined in the project's model.py; its
        # exact effect is not visible from this file.
        self.flowgenerator.decoder.skip()
        self.flowgenerator.eval()

        check_point = torch.load(hifigan_ckpt, map_location=device)
        self.voicegenerator.load_state_dict(check_point['gen_model'])
        self.voicegenerator.eval()
        self.voicegenerator.remove_weight_norm()

    def inference(self, input_text, noise_scale=0.667, length_scale=1.0):
        """Synthesise *input_text* and return the waveform as int16 PCM.

        Args:
            input_text: text to speak; '.', ',', '!' and '?' are stripped first.
            noise_scale: forwarded unchanged to the Glow-TTS generator.
            length_scale: forwarded unchanged to the Glow-TTS generator.

        Returns:
            numpy int16 array, as produced by ``_to_int16``.
        """
        sentence = self._PUNCTUATION.sub('', input_text)
        tokens = text_to_sequence(sentence)
        # Batch of one token sequence plus its length, both on the model device.
        x = torch.tensor(tokens, dtype=torch.long, device=self.device).unsqueeze(0)
        x_length = torch.tensor([x.shape[1]], device=self.device)
        with torch.no_grad():
            # Only the first item of the generator's first output group (the mel)
            # is used; the trailing group (presumably attention) is discarded.
            (mel, *_), *_, (_attn, *_) = self.flowgenerator(
                x, x_length, gen=True,
                noise_scale=noise_scale, length_scale=length_scale,
            )
            waveform = self.voicegenerator(mel)
        return self._to_int16(waveform)

    @staticmethod
    def _to_int16(waveform):
        """Convert a float waveform tensor to a squeezed int16 numpy array.

        Samples are scaled by 32768 (assumes the vocoder output is nominally in
        [-1, 1]) and clamped to the int16 range first. Without the clamp, a
        sample of exactly 1.0 becomes 32768, which overflows int16 on the cast.
        """
        scaled = waveform.squeeze() * 32768.0
        scaled = scaled.clamp(-32768.0, 32767.0)
        return scaled.cpu().numpy().astype('int16')
def init_session_state():
    """Create the default voice model once per Streamlit session.

    The ``init_model`` flag is set only after the model has been built. If the
    load fails (e.g. a missing checkpoint), the next rerun therefore retries it,
    instead of leaving the session marked as initialised with no ``TTS``.
    """
    if "init_model" in st.session_state:
        return
    default_variant = "감기걸린 은식"
    st.session_state.TTS = TTS(default_variant)
    st.session_state.model_variant = default_variant
    st.session_state.init_model = True
def update_model():
    """Rebuild the TTS model for the voice currently stored in session state.

    Only the two known voices trigger a rebuild; any other value leaves the
    existing model in place.
    """
    selected = st.session_state.model_variant
    if selected in ("KSS", "감기걸린 은식"):
        st.session_state.TTS = TTS(selected)
def update_session_state(state_id, state_value):
    """Store *state_value* in Streamlit session state under the key *state_id*.

    The key is formatted to a string first, so non-string ids are accepted.
    """
    key = "{}".format(state_id)
    st.session_state[key] = state_value
def centered_text(input_text, mode="h1"):
    """Render *input_text* centred, wrapped in an HTML element named by *mode*.

    *mode* is the tag name (e.g. "h1", "h5"). The text is emitted as raw HTML
    (``unsafe_allow_html=True``), so callers must not pass untrusted input.
    """
    opening_tag = f"<{mode} style='text-align: center;'>"
    closing_tag = f"</{mode}>"
    st.markdown(f"{opening_tag}{input_text}{closing_tag}", unsafe_allow_html=True)
# Page body: header, description, inputs, voice switching and synthesis.
init_session_state()

centered_text("🔉 소신 Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")

mode = "p"
st.markdown(
    f"<{mode} style='text-align: left;'><small>This is a demo trained by our voice. The voice \"KSS\" is trained by <a href= 'https://www.kaggle.com/datasets/bryanpark/korean-single-speaker-speech-dataset'>KSS Dataset</a>. The voice \"감기걸린 은식\" is trained from pre-trained \"KSS\". We got this demo format from Nix-TTS Interactive Demo</small></{mode}>",
    unsafe_allow_html=True,
)
st.write(" ")
st.write(" ")

col1, col2 = st.columns(2)
with col1:
    input_text = st.text_input(
        "한글로만 입력해주세요",
        value="밥은 먹고 다녀",
    )
with col2:
    model_variant = st.selectbox("목소리 선택해주세요", options=["KSS", "감기걸린 은식"], index=1)

button_change = st.button("Change Voice")
if button_change:
    if model_variant != st.session_state.model_variant:
        # Update variant choice
        update_session_state("model_variant", model_variant)
        # Re-load model
        update_model()
        st.snow()

noise_scale = st.slider('noise를 추가합니다.', 0., 2., value=0.33, step=0.01)
length_scale = st.slider('속도를 조절합니다.', 0., 2., value=1., step=0.01)

button_gen = st.button("Generate Voice")
if button_gen:
    # Blank input would hand an empty sequence to the model; warn instead.
    if not input_text.strip():
        st.warning("텍스트를 입력해주세요.")
    else:
        voice = st.session_state.TTS.inference(input_text, noise_scale, length_scale)
        st.audio(voice, sample_rate=22050)
        st.caption("Generated Voice by " + st.session_state.model_variant)
        st.balloons()
|