# Glow-HiFi-TTS / app.py
import re                           # used by TTS.inference below (was missing)
import streamlit as st
import soundfile as sf
import timeit
import uuid
import os
import torch
import IPython.display as ipd       # used by TTS.inference below (was missing)
from scipy.io.wavfile import write  # used by TTS.inference below (was missing)
from datautils import *  # assumed to provide text_to_sequence / sequence_to_text
from model import Generator as Glow_model
from utils import scan_checkpoint, plot_mel, plot_alignment
from Hmodel import Generator as GAN_model
MAX_WAV_VALUE = 32768.0
device = torch.device('cuda:0')
torch.cuda.manual_seed(1234)
name = '1038_eunsik_01'
# Nix
from nix.models.TTS import NixTTSInference
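
# NixTTSInference loads a pretrained Nix-TTS model from an asset directory.
# The two directories below are assumed to hold the "KSS" and "은식" voices
# selectable in the UI.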
def init_session_state():
    # Model
    if "init_model" not in st.session_state:
        st.session_state.init_model = True
        st.session_state.model_variant = "KSS"
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")

def update_model():
    if st.session_state.model_variant == "KSS":
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-v0.1")
    elif st.session_state.model_variant == "은식":
        st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")

def update_session_state(state_id, state_value):
    st.session_state[f"{state_id}"] = state_value

def centered_text(input_text, mode="h1"):
    st.markdown(
        f"<{mode} style='text-align: center;'>{input_text}</{mode}>", unsafe_allow_html=True)

def generate_voice(input_text):
    # TTS inference: tokenize to phoneme IDs, then vocalize to a waveform
    c, c_length, phoneme = st.session_state.TTS.tokenize(input_text)
    voice = st.session_state.TTS.vocalize(c, c_length)
    # Save audio (Streamlit can't play a NumPy array directly); a uuid filename
    # avoids problems with punctuation in the input text
    os.makedirs("cache_sound", exist_ok=True)
    wav_path = f"cache_sound/{uuid.uuid4()}.wav"
    sf.write(wav_path, voice[0, 0], 22050)
    # Play audio, then remove the cache file
    st.audio(wav_path, format="audio/wav")
    os.remove(wav_path)
    st.caption("Generated Voice")
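
# --- Streamlit page (runs top-to-bottom on every rerun) ----------------------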
st.set_page_config(
    page_title="μ†Œμ‹  Team Demo",
    page_icon="πŸ”‰",
)
init_session_state()
centered_text("πŸ”‰ μ†Œμ‹  Team Demo")
centered_text("mel generator : Glow-TTS, vocoder : HiFi-GAN", "h5")
st.write(" ")
mode = "p"
st.markdown(
    f"<{mode} style='text-align: left;'><small>This is a demo trained on our own voices.&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; The voice \"KSS\" was trained 3 times; \"은식\" was fine-tuned from \"KSS\" 3 times.&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; We adapted this demo format from the Nix-TTS Interactive Demo.</small></{mode}>",
    unsafe_allow_html=True
)
st.write(" ")
st.write(" ")
col1, col2 = st.columns(2)
with col1:
    input_text = st.text_input(
        "ν•œκΈ€λ‘œλ§Œ μž…λ ₯ν•΄μ£Όμ„Έμš”",
        value="λ”₯λŸ¬λ‹μ€ 정말 μž¬λ°Œμ–΄!",
    )
with col2:
    model_variant = st.selectbox("λͺ©μ†Œλ¦¬ μ„ νƒν•΄μ£Όμ„Έμš”", options=["KSS", "은식"], index=1)
    if model_variant != st.session_state.model_variant:
        # Update variant choice
        update_session_state("model_variant", model_variant)
        # Re-load model
        update_model()

button_gen = st.button("Generate Voice")
if button_gen:
    generate_voice(input_text)
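
# --- Alternative Glow-TTS + HiFi-GAN pipeline ---------------------------------
# The class below builds the mel generator (Glow-TTS) and vocoder (HiFi-GAN)
# directly from local checkpoints. It is defined here but never called by the
# Streamlit UI above, which uses NixTTSInference instead.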
class TTS:
    def __init__(self, model_variant):
        self.flowgenerator = Glow_model(n_vocab=70, h_c=192, f_c=768, f_c_dp=256,
                                        out_c=80, k_s=3, k_s_dec=5, heads=2, layers_enc=6)
        self.voicegenerator = GAN_model()
        if model_variant == '은식':
            # Glow-TTS (mel generator) checkpoint
            last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
            check_point = torch.load(last_chpt1)
            self.flowgenerator.load_state_dict(check_point['generator'])
        self.flowgenerator.decoder.skip()
        self.flowgenerator.eval()
        if model_variant == '은식':
            # HiFi-GAN (vocoder) checkpoint
            last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
            check_point = torch.load(last_chpt2)
            self.voicegenerator.load_state_dict(check_point['gen_model'])
        self.voicegenerator.eval()
        self.voicegenerator.remove_weight_norm()
        # Move both networks to the inference device so they match the input
        # tensors created in inference() (they stayed on CPU in the original)
        self.flowgenerator.to(device)
        self.voicegenerator.to(device)

    def inference(self, input_text):
        # Strip punctuation, then convert the sentence to a phoneme-ID sequence
        # (the original referenced `sentence`/`text` before defining them)
        filters = '([.,!?])'
        sentence = re.sub(re.compile(filters), '', input_text)
        x = torch.tensor(text_to_sequence(sentence)).unsqueeze(0).to(device).long()
        x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)
        with torch.no_grad():
            noise_scale = .667   # sampling temperature of the flow prior
            length_scale = 1.0   # >1 slows speech down, <1 speeds it up
            (y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(
                x, x_length, gen=True, noise_scale=noise_scale, length_scale=length_scale)
            y = self.voicegenerator(y_gen_tst)
        audio = (y.squeeze() * MAX_WAV_VALUE).cpu().numpy().astype('int16')
        out_dir = 'cache_sound'  # assumed; `out_dir` was undefined in the original
        os.makedirs(out_dir, exist_ok=True)
        output_file = os.path.join(out_dir, 'gen_' + sentence[:3] + '.wav')
        write(output_file, 22050, audio)
        print(f'{sentence} is stored in {out_dir}')
        # Diagnostics (moved before the return; they were unreachable after it).
        # plot_mel / plot_alignment are assumed to return matplotlib figures.
        fig1 = plot_mel(y_gen_tst[0].data.cpu().numpy())
        fig2 = plot_alignment(attn_gen[0, 0].data.cpu().numpy(),
                              sequence_to_text(x[0].data.cpu().numpy()))
        ipd.display(fig1, fig2)
        ipd.display(ipd.Audio(filename=output_file))
        return audio
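
# A minimal usage sketch for the TTS class (hypothetical; not executed by this
# app, and it assumes the '은식' checkpoints exist under ./log/1038_eunsik_01/):
#
#     tts = TTS('은식')
#     audio = tts.inference('λ”₯λŸ¬λ‹μ€ 정말 μž¬λ°Œμ–΄!')   # int16 waveform at 22.05 kHz
#     sf.write('cache_sound/sample.wav', audio, 22050)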