import datetime
import json
import os
import pickle
import re
import time

import gradio as gr
import IPython.display as ipd
import librosa
import numpy as np
import romajitable
import soundfile as sf
import torch
from scipy import signal
from scipy.io.wavfile import write

import commons
import utils
from mel_processing import spectrogram_torch
from models import SynthesizerTrn
from text import text_to_sequence
from text.symbols import symbols


class VitsGradio:
    """Gradio front-end that turns a novel text file into WAV + SRT files with VITS.

    The "TTS设定" tab loads a checkpoint and inference settings (create_tts_fn);
    the "小说合成" tab splits an uploaded text file into sentences and
    synthesizes them (tts_fn).
    """

    # Folder scanned for model directories; each is expected to hold config.json and model.pth.
    CHECKPOINT_ROOT = "checkpoints"
    # Number of sentences per output WAV/SRT pair; longer books produce several numbered pairs.
    SENTENCES_PER_FILE = 1000
    # Fallback used only when the model config does not declare data.sampling_rate.
    DEFAULT_SAMPLING_RATE = 22050

    _ENGLISH_RE = re.compile('^[A-Za-z0-9.,:;!?()_*"\' ]+$')
    # Every bracket style is mapped to <...> so transfer() can strip bracketed notes in one regex.
    _BRACKET_TABLE = str.maketrans({
        **dict.fromkeys(['【', '[', '(', '(', '〔'], '<'),
        **dict.fromkeys(['】', ']', ')', ')', '〕'], '>'),
    })
    # Per-sentence clean-up, applied once and in this order (the chain is idempotent,
    # so repeating it, as an earlier version did 20 times, changes nothing).
    _SENTENCE_REPLACEMENTS = (
        ('\n', ''), (' ', ''), ('……', '。'), ('…', '。'), ('还', '孩'),
        ('“', ''), ('”', ''), ('!', '。'), ('」', ''), ('「', ''),
    )

    def __init__(self):
        """Build the Gradio UI (self.Vits); the model itself is loaded later by create_tts_fn."""
        self.lan = ["中文","日文","自动"]
        self.modelPaths = self._list_model_dirs()
        # Avoid IndexError when the checkpoint folder is empty or missing.
        default_model = self.modelPaths[0] if self.modelPaths else None
        with gr.Blocks() as self.Vits:
            with gr.Tab("小说合成"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                self.Text = gr.File(label="Text")
                                self.audio_path = gr.TextArea(label="音频路径", lines=1, value='audiobook/chapter.wav')
                                btnbook = gr.Button("小说合成")
                                btnbook.click(self.tts_fn, inputs=[self.Text, self.audio_path])
            with gr.Tab("TTS设定"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            with gr.Column():
                                self.input1 = gr.Dropdown(label="模型", choices=self.modelPaths, value=default_model, type="value")
                                self.input2 = gr.Dropdown(label="Language", choices=self.lan, value="自动", interactive=True)
                                self.input3 = gr.Dropdown(label="Speaker", choices=list(range(1001)), value=0, interactive=True)
                                self.input4 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声比例(noise scale),以控制情感", value=0.6)
                                self.input5 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声偏差(noise scale w),以控制音素长短", value=0.667)
                                self.input6 = gr.Slider(minimum=0.1, maximum=10, label="duration", value=1)
                                statusa = gr.TextArea()
                                btnVC = gr.Button("完成vits TTS端设定")
                                btnVC.click(self.create_tts_fn,
                                            inputs=[self.input1, self.input2, self.input3,
                                                    self.input4, self.input5, self.input6],
                                            outputs=[statusa])

    def _list_model_dirs(self):
        """Return the sorted names of the immediate sub-directories of CHECKPOINT_ROOT.

        Only the top level is scanned, because create_tts_fn looks for
        ``<CHECKPOINT_ROOT>/<name>/config.json``; nested folders are not valid
        choices. Returns an empty list when the folder does not exist.
        """
        root = self.CHECKPOINT_ROOT
        if not os.path.isdir(root):
            return []
        return sorted(entry for entry in os.listdir(root)
                      if os.path.isdir(os.path.join(root, entry)))

    def is_japanese(self, string):
        """Return True if any character lies strictly between U+3040 and U+30FF (hiragana/katakana block)."""
        return any(0x3040 < ord(ch) < 0x30FF for ch in string)

    def is_english(self, string):
        """Return True if the whole string is ASCII letters, digits, spaces and basic punctuation."""
        return self._ENGLISH_RE.fullmatch(string) is not None

    def get_text(self, text, hps, cleaned=False):
        """Convert text to a LongTensor of symbol ids using the symbols/cleaners in ``hps``.

        Args:
            text: input string, already wrapped in language tags by sle().
            hps: hyper-parameters loaded from the model's config.json.
            cleaned: if True, skip the configured text cleaners.
        """
        cleaners = [] if cleaned else hps.data.text_cleaners
        text_norm = text_to_sequence(text, hps.symbols, cleaners)
        if hps.data.add_blank:
            # Interleave id 0 between symbols (presumably the VITS blank token -- see commons.intersperse).
            text_norm = commons.intersperse(text_norm, 0)
        return torch.LongTensor(text_norm)

    def get_label(self, text, label):
        """Return (True, text without "[label]") if the tag is present, else (False, text)."""
        tag = f'[{label}]'
        if tag in text:
            return True, text.replace(tag, '')
        return False, text

    def sle(self, language, text):
        """Wrap text in [ZH]/[JA] tags according to the UI language choice.

        Newlines become '。' and spaces become ','. "自动" picks JA when the text
        contains kana, otherwise ZH.

        Raises:
            ValueError: for a language outside self.lan.
        """
        text = text.replace('\n','。').replace(' ',',')
        if language == "中文":
            tag = "ZH"
        elif language == "日文":
            tag = "JA"
        elif language == "自动":
            tag = "JA" if self.is_japanese(text) else "ZH"
        else:
            raise ValueError(f"Unsupported language: {language!r}")
        return f"[{tag}]{text}[{tag}]"

    def create_tts_fn(self, path, input2, input3, n_scale=0.667, n_scale_w=0.8, l_scale=1):
        """Load the model in ``<CHECKPOINT_ROOT>/<path>`` and store the inference settings.

        Args:
            path: model directory name (contains config.json and model.pth).
            input2: language choice passed to sle().
            input3: speaker id.
            n_scale, n_scale_w, l_scale: noise_scale, noise_scale_w and
                length_scale forwarded to the synthesizer's infer().

        Returns:
            A status string shown in the UI.
        """
        if not path:
            return 'no model selected'
        self.language = input2
        self.speaker_id = int(input3)
        self.n_scale = n_scale
        self.n_scale_w = n_scale_w
        self.l_scale = l_scale
        self.dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.hps_ms = utils.get_hparams_from_file(f"{self.CHECKPOINT_ROOT}/{path}/config.json")
        self.n_speakers = self.hps_ms.data.n_speakers if 'n_speakers' in self.hps_ms.data.keys() else 0
        self.n_symbols = len(self.hps_ms.symbols) if 'symbols' in self.hps_ms.keys() else 0
        self.net_g_ms = SynthesizerTrn(
            self.n_symbols,
            self.hps_ms.data.filter_length // 2 + 1,
            self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
            n_speakers=self.n_speakers,
            **self.hps_ms.model).to(self.dev)
        _ = self.net_g_ms.eval()
        _ = utils.load_checkpoint(f"{self.CHECKPOINT_ROOT}/{path}/model.pth", self.net_g_ms)
        return 'success'

    def transfer(self, text, max_len=50):
        """Split novel text into cleaned sentences ready for synthesis.

        Steps:
        - Remove ``<...>`` spans.
        - Split on line breaks and sentence punctuation.
        - Convert all-ASCII pieces to katakana.
        - Normalise quotes and ellipses.
        - Drop pieces of one character or less.
        - Re-split pieces longer than ``max_len`` on finer punctuation, appending
          '。' to each part.

        Args:
            text: the raw text.
            max_len: longest sentence kept as-is (default 50 characters).

        Returns:
            A list of non-empty sentence strings.
        """
        text = re.sub("<[^>]*>", "", text)
        final_list = []
        for line in text.split('\n'):
            for piece in re.split(r'。|!|——|:|;|……|——|。|!', line):
                if self.is_english(piece):
                    piece = romajitable.to_kana(piece).katakana
                for old, new in self._SENTENCE_REPLACEMENTS:
                    piece = piece.replace(old, new)
                if len(piece) <= 1:
                    continue
                if len(piece) > max_len:
                    final_list.extend(part + '。' for part in re.split(r'。|!|——|,|:', piece)
                                      if len(part) > 1)
                else:
                    final_list.append(piece)
        return final_list

    @staticmethod
    def _format_srt_time(offset):
        """Format a non-negative timedelta as an SRT timestamp ``HH:MM:SS,mmm`` (truncated to ms)."""
        total_ms = offset // datetime.timedelta(milliseconds=1)
        total_s, millis = divmod(total_ms, 1000)
        total_m, seconds = divmod(total_s, 60)
        hours, minutes = divmod(total_m, 60)
        return f"{hours:02d}:{minutes:02d}:{seconds:02d},{millis:03d}"

    def _infer_sentence(self, sentence):
        """Run the loaded synthesizer on one sentence and return its waveform as a float numpy array."""
        with torch.no_grad():
            stn_tst = self.get_text(self.sle(self.language, sentence), self.hps_ms, cleaned=False)
            x_tst = stn_tst.unsqueeze(0).to(self.dev)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(self.dev)
            sid = torch.LongTensor([self.speaker_id]).to(self.dev)
            return self.net_g_ms.infer(x_tst, x_tst_lengths, sid=sid,
                                       noise_scale=self.n_scale,
                                       noise_scale_w=self.n_scale_w,
                                       length_scale=self.l_scale)[0][0, 0].data.cpu().float().numpy()

    def _synthesize_chunk(self, sentences, out_stem, sampling_rate):
        """Synthesize one chunk of sentences into ``<out_stem>.wav`` and ``<out_stem>.srt``.

        Each sentence's audio is upsampled 2x and written at ``2 * sampling_rate``.
        A subtitle cue is written only for sentences whose audio was kept, so
        subtitles stay aligned with the audio. Failing sentences are reported
        and skipped (best effort). No WAV is written if every sentence failed.
        """
        audio_fin = []
        elapsed_samples = 0  # integer sample count at the model rate: drift-free subtitle clock
        cue_index = 0
        with open(out_stem + ".srt", 'w', encoding='utf-8') as srt_file:
            for c, sentence in enumerate(sentences, start=1):
                try:
                    t1 = time.time()
                    audio = self._infer_sentence(sentence)
                    t2 = time.time()
                    resampled_audio_data = signal.resample(audio, len(audio) * 2)
                except Exception as exc:
                    print(f"第{c}句合成失败, 已跳过: {exc!r}")
                    continue
                print("第"+str(c)+"句的推理时间为:"+str(t2-t1)+"s")
                time_start = self._format_srt_time(
                    datetime.timedelta(microseconds=elapsed_samples * 1000000 // sampling_rate))
                elapsed_samples += len(audio)
                time_end = self._format_srt_time(
                    datetime.timedelta(microseconds=elapsed_samples * 1000000 // sampling_rate))
                print(time_end)
                cue_index += 1  # SRT cue numbers start at 1
                srt_file.write(f"{cue_index}\n{time_start} --> {time_end}\n"
                               f"{sentence.replace('。', '')}\n\n")
                audio_fin.append(resampled_audio_data)
        if audio_fin:
            sf.write(out_stem + '.wav', np.concatenate(audio_fin), sampling_rate * 2, 'PCM_24')
        else:
            print(f"{out_stem}: no sentence was synthesized; WAV not written")

    def tts_fn(self, text, audio_path):
        """Synthesize an uploaded novel into numbered WAV + SRT file pairs.

        Args:
            text: the value of the gr.File input -- an object with a ``.name``
                path or a plain path string, depending on the Gradio version
                (NOTE(review): confirm for the pinned Gradio release).
            audio_path: output template such as ``audiobook/chapter.wav``. Chunk
                k is written to ``audiobook/chapterk.wav`` and
                ``audiobook/chapterk.srt``. The output directory is created if
                needed.

        Raises:
            RuntimeError: if no model has been loaded via create_tts_fn.
            ValueError: if no file was uploaded.
        """
        if not hasattr(self, "net_g_ms"):
            raise RuntimeError("No TTS model loaded; finish the 'TTS设定' tab first.")
        if text is None:
            raise ValueError("No text file uploaded.")
        with open(getattr(text, "name", text), "r", encoding="utf-8") as f:
            content = f.read()
        final_list = self.transfer(content.translate(self._BRACKET_TABLE))

        sampling_rate = int(getattr(self.hps_ms.data, "sampling_rate", self.DEFAULT_SAMPLING_RATE))
        # splitext instead of str.replace('.wav', ...): a path without ".wav" would
        # otherwise make the WAV overwrite the SRT.
        out_root = os.path.splitext(audio_path)[0]
        out_dir = os.path.dirname(out_root)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        chunk = self.SENTENCES_PER_FILE
        for c0, start in enumerate(range(0, len(final_list), chunk)):
            self._synthesize_chunk(final_list[start:start + chunk], f"{out_root}{c0}", sampling_rate)


if __name__ == '__main__':
    print("开始部署")
    grVits = VitsGradio()
    grVits.Vits.launch()