marigold334 committed on
Commit
e44a7c2
•
1 Parent(s): 39097f7

Update app.py

Files changed (1)
  1. app.py +46 -63
app.py CHANGED
@@ -1,37 +1,61 @@
  import streamlit as st
  import soundfile as sf
- import timeit
- import uuid
-
- import os
-
+ import os, re
  import torch
-
  from datautils import *
  from model import Generator as Glow_model
- from utils import scan_checkpoint, plot_mel, plot_alignment
  from Hmodel import Generator as GAN_model

- MAX_WAV_VALUE = 32768.0
- device = torch.device('cuda:0')
- torch.cuda.manual_seed(1234)
- name = '1038_eunsik_01'
-
- # Nix
- from nix.models.TTS import NixTTSInference
+ device = torch.device('cuda:0') if torch.cuda.is_available() else 'cpu'
+ torch.cuda.manual_seed(1234) if torch.cuda.is_available() else None
+
+ class TTS:
+     def __init__(self, model_variant):
+         self.flowgenerator = Glow_model(n_vocab = 70, h_c = 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads = 2, layers_enc = 6)
+         self.voicegenerator = GAN_model()
+         if model_variant == '은식':
+             name = '1038_eunsik_01'
+             last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
+             check_point = torch.load(last_chpt1)
+             self.flowgenerator.load_state_dict(check_point['generator'])
+             self.flowgenerator.decoder.skip()
+             self.flowgenerator.eval()
+         if model_variant == '은식':
+             name = '1038_eunsik_01'
+             last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
+             check_point = torch.load(last_chpt2)
+             self.voicegenerator.load_state_dict(check_point['gen_model'])
+             self.voicegenerator.eval()
+             self.voicegenerator.remove_weight_norm()
+
+     def inference(self, input_text):
+         filters = '([.,!?])'
+         sentence = re.sub(re.compile(filters), '', input_text)
+         x = text_to_sequence(sentence)
+         x = torch.autograd.Variable(torch.tensor(x).unsqueeze(0)).to(device).long()
+         x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)
+
+         with torch.no_grad():
+             noise_scale = .667
+             length_scale = 1.0
+             (y_gen_tst, *_), *_, (attn_gen, *_) = self.flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
+             y = self.voicegenerator(y_gen_tst)
+             audio = y.squeeze() * 32768.0
+             voice = audio.cpu().numpy().astype('int16')
+         return voice

  def init_session_state():
      # Model
      if "init_model" not in st.session_state:
          st.session_state.init_model = True
-         st.session_state.model_variant = "KSS"
-         st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")
+         st.session_state.model_variant = "은식"
+         st.session_state.TTS = TTS("은식")

  def update_model():
      if st.session_state.model_variant == "KSS":
-         st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-v0.1")
+         st.session_state.TTS = TTS("KSS")
      elif st.session_state.model_variant == "은식":
-         st.session_state.TTS = NixTTSInference("assets/nix-ljspeech-sdp-v0.1")
+         st.session_state.TTS = TTS("은식")

  def update_session_state(state_id, state_value):
      st.session_state[f"{state_id}"] = state_value
@@ -40,19 +64,19 @@ def centered_text(input_text, mode = "h1",):
      st.markdown(
          f"<{mode} style='text-align: center;'>{input_text}</{mode}>", unsafe_allow_html = True)

- def generate_voice(input_text,):
+ def generate_voice(input_text):
      # TTS Inference
-     c, c_length, phoneme = st.session_state.TTS.tokenize(input_text)
-     voice = st.session_state.TTS.vocalize(c, c_length)
+     voice = st.session_state.TTS.inference(input_text)

      # Save audio (bug in Streamlit, can't play numpy array directly)
-     sf.write(f"cache_sound/{input_text}.wav", voice[0,0], 22050)
+     sf.write(f"cache_sound/{input_text}.wav", voice, 22050)

      # Play audio
      st.audio(f"cache_sound/{input_text}.wav", format = "audio/wav")
      os.remove(f"cache_sound/{input_text}.wav")
      st.caption("Generated Voice")
-
+
+
  st.set_page_config(
      page_title = "소신 Team Demo",
      page_icon = "🔉",
@@ -92,44 +116,3 @@ if button_gen == True:
      generate_voice(input_text)


- class TTS:
-     def __init__(self, model_variant):
-         self.flowgenerator = Glow_model(n_vocab = 70, h_c = 192, f_c = 768, f_c_dp = 256, out_c = 80, k_s = 3, k_s_dec = 5, heads = 2, layers_enc = 6)
-         self.voicegenerator = GAN_model()
-         if model_variant == '은식':
-             last_chpt1 = './log/1038_eunsik_01/Glow_TTS_00289602.pt'
-             check_point = torch.load(last_chpt1)
-             self.flowgenerator.load_state_dict(check_point['generator'])
-             self.flowgenerator.decoder.skip()
-             self.flowgenerator.eval()
-         if model_variant == '은식':
-             last_chpt2 = './log/1038_eunsik_01/HiFI_GAN_00257000.pt'
-             check_point = torch.load(last_chpt2)
-             self.voicegenerator.load_state_dict(check_point['gen_model'])
-             self.voicegenerator.eval()
-             self.voicegenerator.remove_weight_norm()
-
-     def inference(self, input_text):
-         x = text_to_sequence(sentence)
-         filters = '([.,!?])'
-         sentence = re.sub(re.compile(filters), '', text)
-         x = torch.autograd.Variable(torch.tensor(x).unsqueeze(0)).to(device).long()
-         x_length = torch.tensor(x.shape[1]).unsqueeze(0).to(device)
-
-         with torch.no_grad():
-             noise_scale = .667
-             length_scale = 1.0
-             (y_gen_tst, *_), *_, (attn_gen, *_) = flowgenerator(x, x_length, gen = True, noise_scale = noise_scale, length_scale = length_scale)
-             y = voicegenerator(y_gen_tst)
-             audio = y.squeeze() * MAX_WAV_VALUE
-             audio = audio.cpu().numpy().astype('int16')
-
-             output_file = os.path.join(out_dir, 'gen_'+text[:3]+'.wav')
-             write(output_file, 22050, audio)
-             print(f'{text} is stored in {out_dir}')
-
-         return voice
-         plot_mel(y_gen_tst[0].data.cpu().numpy())
-         plot_alignment(attn_gen[0,0].data.cpu().numpy(), sequence_to_text(x[0].data.cpu().numpy()))
-         ipd.display(fig1,fig2)
-         ipd.Audio(filename=output_file)
118