# Spaces:
# Runtime error
# Runtime error
import io
import os
import re

import soundfile as sf
import streamlit as st
import torch

from datautils import *
from model import Generator as Glow_model
from Hmodel import Generator as GAN_model
# Browser-tab title and icon; kept ahead of every other Streamlit call in this script.
st.set_page_config(page_title="μμ Team Demo", page_icon="π")
class TTS:
    """Text-to-speech pipeline: a mel generator followed by a vocoder.

    The page text names the two stages as Glow-TTS (mel generator) and
    HiFi-GAN (vocoder). Both networks are built, loaded from the checkpoints
    registered for the requested voice, and switched to evaluation mode in
    ``__init__``; ``inference`` then turns a sentence into 16-bit PCM samples.
    """

    # Voice name -> (mel-generator checkpoint, vocoder checkpoint).
    # Register further voices here; a voice without an entry is rejected
    # up front instead of running with untrained weights.
    _CHECKPOINTS = {
        'μμ': (
            './log/1038_eunsik_01/Glow_TTS_00289602.pt',
            './log/1038_eunsik_01/HiFI_GAN_00257000.pt',
        ),
    }

    # Punctuation removed from the input before it is converted to symbol ids.
    _PUNCTUATION = re.compile(r'([.,!?])')

    def __init__(self, model_variant):
        """Build both networks and load the checkpoints for ``model_variant``.

        Args:
            model_variant: Voice name; must be a key of ``_CHECKPOINTS``.

        Raises:
            ValueError: If no checkpoints are registered for ``model_variant``.
        """
        # Validate first so an unsupported voice fails fast with a clear
        # message, before any model is allocated.
        try:
            flow_path, vocoder_path = self._CHECKPOINTS[model_variant]
        except KeyError:
            raise ValueError(
                f"No checkpoints registered for voice {model_variant!r}; "
                f"known voices: {sorted(self._CHECKPOINTS)}"
            ) from None

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The module-level name is still published for backward compatibility
        # with code that reads the global `device`.
        global device
        device = self.device
        if torch.cuda.is_available():
            torch.cuda.manual_seed(1234)

        self.flowgenerator = Glow_model(
            n_vocab=70, h_c=192, f_c=768, f_c_dp=256, out_c=80,
            k_s=3, k_s_dec=5, heads=2, layers_enc=6,
        ).to(self.device)
        self.voicegenerator = GAN_model().to(self.device)

        # Mel generator weights.
        check_point = torch.load(flow_path, map_location=self.device)
        self.flowgenerator.load_state_dict(check_point['generator'])
        # Project-defined decoder hook; presumably prepares the decoder for
        # inference -- TODO confirm against model.py.
        self.flowgenerator.decoder.skip()
        self.flowgenerator.eval()

        # Vocoder weights.
        check_point = torch.load(vocoder_path, map_location=self.device)
        self.voicegenerator.load_state_dict(check_point['gen_model'])
        self.voicegenerator.eval()
        self.voicegenerator.remove_weight_norm()

    @classmethod
    def _clean_text(cls, input_text):
        """Return ``input_text`` with the characters ``. , ! ?`` removed."""
        return cls._PUNCTUATION.sub('', input_text)

    @staticmethod
    def _to_pcm16(waveform):
        """Convert a float waveform tensor to a NumPy ``int16`` array.

        Samples are scaled by 32768 and clamped to the int16 range: a
        full-scale +1.0 sample would otherwise become 32768, which does not
        fit in int16 and wraps to -32768.
        """
        scaled = torch.clamp(waveform.squeeze() * 32768.0, -32768.0, 32767.0)
        return scaled.detach().cpu().numpy().astype('int16')

    def inference(self, input_text):
        """Synthesise ``input_text`` and return the audio as ``int16`` samples.

        Args:
            input_text: Sentence to speak; ``. , ! ?`` are stripped before it
                is passed to ``text_to_sequence``.

        Returns:
            NumPy ``int16`` array produced from the vocoder output.
        """
        sentence = self._clean_text(input_text)
        x = text_to_sequence(sentence)
        # Add a leading batch axis of size 1.
        x = torch.tensor(x).unsqueeze(0).to(self.device).long()
        x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Only the first item of the first returned group is used; the
            # unpacking pattern is kept so the expected return structure is
            # still enforced.
            (mel, *_), *_, (_attn, *_) = self.flowgenerator(
                x, x_length, gen=True, noise_scale=.667, length_scale=1.0,
            )
            waveform = self.voicegenerator(mel)
        return self._to_pcm16(waveform)
def init_session_state():
    """Load the default voice into the Streamlit session on the first run.

    Does nothing once ``init_model`` is present in ``st.session_state``.
    The model is built before any session key is written, so a failed load
    does not leave the session marked as initialised without a ``TTS`` entry.
    """
    if "init_model" in st.session_state:
        return

    default_variant = "μμ"
    st.session_state.TTS = TTS(default_variant)
    st.session_state.model_variant = default_variant
    st.session_state.init_model = True
def update_model():
    """Rebuild ``st.session_state.TTS`` for the currently selected voice.

    Only the two voices offered by the page are acted on; any other value
    of ``st.session_state.model_variant`` leaves the loaded model untouched.
    """
    variant = st.session_state.model_variant
    if variant in ("KSS", "μμ"):
        st.session_state.TTS = TTS(variant)
def update_session_state(state_id, state_value):
    """Store ``state_value`` in the Streamlit session under the formatted ``state_id``."""
    key = format(state_id)
    st.session_state[key] = state_value
def centered_text(input_text, mode="h1"):
    """Render ``input_text`` centred on the page inside an HTML ``<mode>`` element.

    Args:
        input_text: Text placed inside the element (inserted as raw HTML).
        mode: Tag name to wrap the text in, ``"h1"`` by default.
    """
    html = f"<{mode} style='text-align: center;'>{input_text}</{mode}>"
    st.markdown(html, unsafe_allow_html=True)
def generate_voice(input_text, sample_rate=22050):
    """Synthesise ``input_text`` with the session's model and show an audio player.

    The WAV data is encoded in memory rather than written to
    ``cache_sound/<input_text>.wav``. User text therefore never becomes a
    file name (no path separators, invalid characters or over-long names),
    no cache directory has to exist, and nothing is left behind if playback
    setup fails.

    Args:
        input_text: Sentence to synthesise.
        sample_rate: Sample rate written into the WAV header, in Hz.
    """
    # TTS inference
    voice = st.session_state.TTS.inference(input_text)

    # Encode as 16-bit PCM WAV; the format must be given explicitly because
    # there is no file extension to infer it from.
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, voice, sample_rate, format="WAV", subtype="PCM_16")

    # Play audio
    st.audio(wav_buffer.getvalue(), format="audio/wav")
    st.caption("Generated Voice")
# ---- Page body: runs top to bottom on every Streamlit rerun ----

# Make sure a model is loaded for this session.
init_session_state()

# Title and subtitle.
centered_text("π μμ Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")

# Introductory paragraph, rendered as raw HTML.
mode = "p"
intro_html = f"<{mode} style='text-align: left;'><small>This is a demo trained by our vocie. The voice \"KSS\" is traind 3 times \"μμ\" is finetuned from \"KSS\" for 3 times We got this deomoformat from Nix-TTS Interactive Demo</small></{mode}>"
st.markdown(intro_html, unsafe_allow_html=True)
st.write(" ")
st.write(" ")

# Left column: text to speak. Right column: voice selection.
col1, col2 = st.columns(2)
with col1:
    input_text = st.text_input(
        "νκΈλ‘λ§ μ λ ₯ν΄μ£ΌμΈμ",
        value="λ₯λ¬λμ μ λ§ μ¬λ°μ΄!",
    )
with col2:
    model_variant = st.selectbox(
        "λͺ©μ리 μ νν΄μ£ΌμΈμ",
        options=["KSS", "μμ"],
        index=1,
    )
    if model_variant != st.session_state.model_variant:
        # Remember the new choice, then reload the model for it.
        update_session_state("model_variant", model_variant)
        update_model()

# Synthesise only when the button was clicked during this rerun.
button_gen = st.button("Generate Voice")
if button_gen:
    generate_voice(input_text)