Mahiruoshi commited on
Commit
5c5159a
1 Parent(s): a46a731

Upload 2 files

Browse files
Files changed (2) hide show
  1. main.py +241 -0
  2. 部署流程.md +9 -0
main.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import gradio as gr
3
+ import torch
4
+ import unicodedata
5
+ import commons
6
+ import utils
7
+ import pathlib
8
+ from models import SynthesizerTrn
9
+ from text import text_to_sequence
10
+ import time
11
+ import os
12
+ import io
13
+ from scipy.io.wavfile import write
14
+ from flask import Flask, request
15
+ from threading import Thread
16
+ import openai
17
+ import requests
18
+ import json
19
+ import soundfile as sf
20
+ from scipy import signal
21
+ class VitsGradio:
22
+ def __init__(self):
23
+ self.lan = ["中文","日文","自动"]
24
+ self.chatapi = ["gpt-3.5-turbo","gpt3"]
25
+ self.modelPaths = []
26
+ for root,dirs,files in os.walk("checkpoints"):
27
+ for dir in dirs:
28
+ self.modelPaths.append(dir)
29
+ with gr.Blocks() as self.Vits:
30
+ with gr.Tab("调试用"):
31
+ with gr.Row():
32
+ with gr.Column():
33
+ with gr.Row():
34
+ with gr.Column():
35
+ self.text = gr.TextArea(label="Text", value="你好")
36
+ with gr.Accordion(label="测试api", open=False):
37
+ self.local_chat1 = gr.Checkbox(value=False, label="使用网址+文本进行模拟")
38
+ self.url_input = gr.TextArea(label="键入测试", value="http://127.0.0.1:8080/chat?Text=")
39
+ butto = gr.Button("模拟前端抓取语音文件")
40
+ btnVC = gr.Button("测试tts+对话程序")
41
+ with gr.Column():
42
+ output2 = gr.TextArea(label="回复")
43
+ output1 = gr.Audio(label="采样率22050")
44
+ output3 = gr.outputs.File(label="44100hz: output.wav")
45
+ butto.click(self.Simul, inputs=[self.text, self.url_input], outputs=[output2,output3])
46
+ btnVC.click(self.tts_fn, inputs=[self.text], outputs=[output1,output2])
47
+ with gr.Tab("控制面板"):
48
+ with gr.Row():
49
+ with gr.Column():
50
+ with gr.Row():
51
+ with gr.Column():
52
+ self.api_input1 = gr.TextArea(label="输入api-key或本地存储说话模型的路径", value="https://platform.openai.com/account/api-keys")
53
+ with gr.Accordion(label="chatbot选择", open=False):
54
+ self.api_input2 = gr.Checkbox(value=True, label="采用gpt3.5")
55
+ self.local_chat1 = gr.Checkbox(value=False, label="启动本地chatbot")
56
+ self.local_chat2 = gr.Checkbox(value=True, label="是否量化")
57
+ res = gr.TextArea()
58
+ Botselection = gr.Button("完成chatbot设定")
59
+ Botselection.click(self.check_bot, inputs=[self.api_input1,self.api_input2,self.local_chat1,self.local_chat2], outputs = [res])
60
+ self.input1 = gr.Dropdown(label = "模型", choices = self.modelPaths, value = self.modelPaths[0], type = "value")
61
+ self.input2 = gr.Dropdown(label="Language", choices=self.lan, value="自动", interactive=True)
62
+ with gr.Column():
63
+ btnVC = gr.Button("完成vits TTS端设定")
64
+ self.input3 = gr.Dropdown(label="Speaker", choices=list(range(101)), value=0, interactive=True)
65
+ self.input4 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声比例(noise scale),以控制情感", value=0.267)
66
+ self.input5 = gr.Slider(minimum=0, maximum=1.0, label="更改噪声偏差(noise scale w),以控制音素长短", value=0.7)
67
+ self.input6 = gr.Slider(minimum=0.1, maximum=10, label="duration", value=1)
68
+ statusa = gr.TextArea()
69
+ btnVC.click(self.create_tts_fn, inputs=[self.input1, self.input2, self.input3, self.input4, self.input5, self.input6], outputs = [statusa])
70
+
71
+ def Simul(self,text,url_input):
72
+ web = url_input + text
73
+ res = requests.get(web)
74
+ music = res.content
75
+ with open('output.wav', 'wb') as code:
76
+ code.write(music)
77
+ file_path = "output.wav"
78
+ return web,file_path
79
+
80
+
81
+ def chatgpt(self,text):
82
+ self.messages.append({"role": "user", "content": text},)
83
+ chat = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages= self.messages)
84
+ reply = chat.choices[0].message.content
85
+ return reply
86
+
87
+ def ChATGLM(self,text):
88
+ if text == 'clear':
89
+ self.history = []
90
+ response, new_history = self.model.chat(self.tokenizer, text, self.history)
91
+ response = response.replace(" ",'').replace("\n",'.')
92
+ self.history = new_history
93
+ return response
94
+
95
+ def gpt3_chat(self,text):
96
+ call_name = "Waifu"
97
+ openai.api_key = args.key
98
+ identity = ""
99
+ start_sequence = '\n'+str(call_name)+':'
100
+ restart_sequence = "\nYou: "
101
+ if 1 == 1:
102
+ prompt0 = text #当期prompt
103
+ if text == 'quit':
104
+ return prompt0
105
+ prompt = identity + prompt0 + start_sequence
106
+ response = openai.Completion.create(
107
+ model="text-davinci-003",
108
+ prompt=prompt,
109
+ temperature=0.5,
110
+ max_tokens=1000,
111
+ top_p=1.0,
112
+ frequency_penalty=0.5,
113
+ presence_penalty=0.0,
114
+ stop=["\nYou:"]
115
+ )
116
+ return response['choices'][0]['text'].strip()
117
+
118
+ def check_bot(self,api_input1,api_input2,local_chat1,local_chat2):
119
+ if local_chat1:
120
+ from transformers import AutoTokenizer, AutoModel
121
+ self.tokenizer = AutoTokenizer.from_pretrained(api_input1, trust_remote_code=True)
122
+ if local_chat2:
123
+ self.model = AutoModel.from_pretrained(api_input1, trust_remote_code=True).half().quantize(4).cuda()
124
+ else:
125
+ self.model = AutoModel.from_pretrained(api_input1, trust_remote_code=True)
126
+ self.history = []
127
+ else:
128
+ self.messages = []
129
+ openai.api_key = api_input1
130
+ return "Finished"
131
+
132
+ def is_japanese(self,string):
133
+ for ch in string:
134
+ if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
135
+ return True
136
+ return False
137
+
138
+ def is_english(self,string):
139
+ import re
140
+ pattern = re.compile('^[A-Za-z0-9.,:;!?()_*"\' ]+$')
141
+ if pattern.fullmatch(string):
142
+ return True
143
+ else:
144
+ return False
145
+
146
+
147
+
148
+ def get_text(self,text, hps, cleaned=False):
149
+ if cleaned:
150
+ text_norm = text_to_sequence(text, self.hps_ms.symbols, [])
151
+ else:
152
+ text_norm = text_to_sequence(text, self.hps_ms.symbols, self.hps_ms.data.text_cleaners)
153
+ if self.hps_ms.data.add_blank:
154
+ text_norm = commons.intersperse(text_norm, 0)
155
+ text_norm = torch.LongTensor(text_norm)
156
+ return text_norm
157
+
158
+
159
+ def get_label(self,text, label):
160
+ if f'[{label}]' in text:
161
+ return True, text.replace(f'[{label}]', '')
162
+ else:
163
+ return False, text
164
+
165
+ def sle(self,language,text):
166
+ text = text.replace('\n','。').replace(' ',',')
167
+ if language == "中文":
168
+ tts_input1 = "[ZH]" + text + "[ZH]"
169
+ return tts_input1
170
+ elif language == "自动":
171
+ tts_input1 = f"[JA]{text}[JA]" if self.is_japanese(text) else f"[ZH]{text}[ZH]"
172
+ return tts_input1
173
+ elif language == "日文":
174
+ tts_input1 = "[JA]" + text + "[JA]"
175
+ return tts_input1
176
+
177
+ def create_tts_fn(self,path, input2, input3, n_scale= 0.667,n_scale_w = 0.8, l_scale = 1 ):
178
+ self.language = input2
179
+ self.speaker_id = int(input3)
180
+ self.n_scale = n_scale
181
+ self.n_scale_w = n_scale_w
182
+ self.l_scale = l_scale
183
+ self.dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
184
+ self.hps_ms = utils.get_hparams_from_file(f"checkpoints/{path}/config.json")
185
+ self.n_speakers = self.hps_ms.data.n_speakers if 'n_speakers' in self.hps_ms.data.keys() else 0
186
+ self.n_symbols = len(self.hps_ms.symbols) if 'symbols' in self.hps_ms.keys() else 0
187
+ self.net_g_ms = SynthesizerTrn(
188
+ self.n_symbols,
189
+ self.hps_ms.data.filter_length // 2 + 1,
190
+ self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
191
+ n_speakers=self.n_speakers,
192
+ **self.hps_ms.model).to(self.dev)
193
+ _ = self.net_g_ms.eval()
194
+ _ = utils.load_checkpoint(f"checkpoints/{path}/model.pth", self.net_g_ms)
195
+ return 'success'
196
+
197
+
198
+ def tts_fn(self,text):
199
+ if self.local_chat1:
200
+ text = self.chatgpt(text)
201
+ elif self.api_input2:
202
+ text = self.ChATGLM(text)
203
+ else:
204
+ text = self.gpt3_chat(text)
205
+ print(text)
206
+ text =self.sle(self.language,text)
207
+ with torch.no_grad():
208
+ stn_tst = self.get_text(text, self.hps_ms, cleaned=False)
209
+ x_tst = stn_tst.unsqueeze(0).to(self.dev)
210
+ x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(self.dev)
211
+ sid = torch.LongTensor([self.speaker_id]).to(self.dev)
212
+ audio = self.net_g_ms.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=self.n_scale, noise_scale_w=self.n_scale_w, length_scale=self.l_scale)[0][
213
+ 0, 0].data.cpu().float().numpy()
214
+ resampled_audio_data = signal.resample(audio, len(audio) * 2)
215
+ sf.write('temp.wav', resampled_audio_data, 44100, 'PCM_24')
216
+ return (self.hps_ms.data.sampling_rate, audio),text.replace('[JA]','').replace('[ZH]','')
217
+
218
+ app = Flask(__name__)
219
+ print("开始部���")
220
+ grVits = VitsGradio()
221
+
222
+ @app.route('/chat')
223
+ def text_api():
224
+ message = request.args.get('Text','')
225
+ audio,text = grVits.tts_fn(message)
226
+ text = text.replace('[JA]','').replace('[ZH]','')
227
+ with open('temp.wav','rb') as bit:
228
+ wav_bytes = bit.read()
229
+ headers = {
230
+ 'Content-Type': 'audio/wav',
231
+ 'Text': text.encode('utf-8')}
232
+ return wav_bytes, 200, headers
233
+
234
+ def gradio_interface():
235
+ return grVits.Vits.launch()
236
+
237
+ if __name__ == '__main__':
238
+ api_thread = Thread(target=app.run, args=("0.0.0.0", 8080))
239
+ gradio_thread = Thread(target=gradio_interface)
240
+ api_thread.start()
241
+ gradio_thread.start()
部署流程.md ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ ```sh
2
+ #前置条件 已安装Anaconda
3
+ conda create -n chatbot python=3.8
4
+ conda activate chatbot
5
+ git clone https://huggingface.co/spaces/Mahiruoshi/vits-chatbot
6
+ cd vits-chatbot
7
+ pip install -r requirements.txt
8
+ python main.py
9
+ ```