# Hosting-page banner captured together with the source ("Spaces: Runtime error");
# it is not part of the program.
import base64
import io
import re
import webbrowser
from pathlib import Path

import librosa
import numpy as np
from flask import Flask, Response, request, render_template
from flask_cors import CORS
from flask_wtf import CSRFProtect
from gevent import pywsgi as wsgi
from scipy.io.wavfile import write

from encoder import inference as encoder
from synthesizer.inference import Synthesizer
from vocoder.hifigan import inference as gan_vocoder
from vocoder.wavernn import inference as rnn_vocoder
from web.api import api_blueprint
# Characters that end a sentence fragment; used to punctuate and split/clean text.
_PUNCTUATION = '!,。、,'
# One or more delimiters (or line breaks) in a row count as a single break.
_SENTENCE_BREAKS = re.compile('[{}\n]+'.format(_PUNCTUATION))


def _pcm2float(sig, dtype='float32'):
    """Convert PCM signal to floating point with a range from -1 to 1.

    Use dtype='float32' for single precision.

    Parameters
    ----------
    sig : array_like
        Input array, must have integral type.
    dtype : data type, optional
        Desired (floating point) data type.

    Returns
    -------
    numpy.ndarray
        Normalized floating point data.

    Raises
    ------
    TypeError
        If `sig` is not an integer array or `dtype` is not a float type.
    """
    sig = np.asarray(sig)
    if sig.dtype.kind not in 'iu':
        raise TypeError("'sig' must be an array of integers")
    dtype = np.dtype(dtype)
    if dtype.kind != 'f':
        raise TypeError("'dtype' must be a floating point type")

    info = np.iinfo(sig.dtype)
    abs_max = 2 ** (info.bits - 1)
    # Zero for signed types; half the range for unsigned ones, which
    # re-centres unsigned samples around zero.
    offset = info.min + abs_max
    return (sig.astype(dtype) - offset) / abs_max


def _split_sentences(text):
    """Split raw request text into clean, non-empty fragments.

    The text is broken on line breaks and on any run of the characters in
    ``_PUNCTUATION``. Each fragment is stripped *before* the emptiness
    check, so whitespace-only fragments are dropped instead of being
    returned as empty strings.

    Parameters
    ----------
    text : str
        Raw text, possibly spanning several lines.

    Returns
    -------
    list of str
        Stripped fragments in their original order; may be empty.
    """
    stripped = (fragment.strip() for fragment in _SENTENCE_BREAKS.split(text))
    return [fragment for fragment in stripped if fragment]


def webApp():
    """Build the Flask application, load the models and serve it.

    Loads the configuration object ``web.config.default``, registers
    ``api_blueprint``, enables CSRF protection, loads the encoder and both
    vocoder checkpoints, opens the configured address in a web browser and
    then blocks in the gevent WSGI server's ``serve_forever()``.

    Returns
    -------
    flask.Flask
        The application object; only reached once ``serve_forever()``
        returns.
    """
    # Init and load config
    app = Flask(__name__, instance_relative_config=True)
    app.config.from_object("web.config.default")
    app.config['RESTPLUS_MASK_SWAGGER'] = False
    app.register_blueprint(api_blueprint)

    # CORS(app)  # Enables cross-origin requests; keep this line commented out to forbid them.
    # Passing the app to the constructor already initialises the extension,
    # so no separate csrf.init_app(app) call is made.
    csrf = CSRFProtect(app)

    syn_models_dirt = "synthesizer/saved_models"
    synthesizers = list(Path(syn_models_dirt).glob("**/*.pt"))
    # Cache for synthesizer: str(checkpoint path) -> loaded Synthesizer.
    synthesizers_cache = {}
    encoder.load_model(Path("encoder/saved_models/pretrained.pt"))
    rnn_vocoder.load_model(Path("vocoder/saved_models/pretrained/pretrained.pt"))
    gan_vocoder.load_model(Path("vocoder/saved_models/pretrained/g_hifigan.pt"))

    # NOTE(review): no route decorator is visible on this handler in the
    # reviewed source, so nothing here binds it to a URL. Confirm that the
    # @app.route(...) registration was not lost.
    def synthesize():
        """Synthesize speech for the posted text in the posted voice.

        Reads from the request: form field ``text`` (required); form field
        ``synt_path`` (optional checkpoint path); reference audio as either
        form field ``upfile_b64`` (base64, decoded as int16 PCM) or
        uploaded file ``file``; form field ``vocoder`` (``"WaveRNN"``
        selects the WaveRNN vocoder, anything else HiFi-GAN).

        Returns a ``Response`` with mimetype ``audio/wav``, or a plain-text
        error response when no synthesizer checkpoint or no usable text is
        available.
        """
        # TODO Implementation with json to support more platform

        # Load synthesizer
        if "synt_path" in request.form:
            # NOTE(review): this path comes straight from the request and is
            # handed to Synthesizer unvalidated. Consider restricting it to
            # the checkpoints discovered in `synthesizers`.
            synt_path = request.form["synt_path"]
        elif synthesizers:
            synt_path = synthesizers[0]
            print("NO synthsizer is specified, try default first one.")
        else:
            return Response(
                "No synthesizer model is available.",
                status=500,
                mimetype="text/plain",
            )

        # Key on the string form so a form-supplied str and a discovered
        # Path naming the same checkpoint share one cache entry.
        cache_key = str(synt_path)
        current_synt = synthesizers_cache.get(cache_key)
        if current_synt is None:
            current_synt = Synthesizer(Path(synt_path))
            synthesizers_cache[cache_key] = current_synt
        print("using synthesizer model: " + str(synt_path))

        # Load input wav
        if "upfile_b64" in request.form:
            raw_pcm = base64.b64decode(request.form["upfile_b64"].encode('utf-8'))
            wav = _pcm2float(np.frombuffer(raw_pcm, dtype=np.int16), dtype=np.float32)
            sample_rate = Synthesizer.sample_rate
        else:
            wav, sample_rate = librosa.load(request.files['file'])
            # Debug dump to make sure we get the correct wav. The path is
            # fixed, so every upload overwrites the previous one.
            write("temp.wav", sample_rate, wav)

        encoder_wav = encoder.preprocess_wav(wav, sample_rate)
        embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Load input text
        texts = _split_sentences(request.form["text"])
        if not texts:
            # Nothing to synthesize; concatenating zero spectrograms would raise.
            return Response(
                "No text to synthesize.",
                status=400,
                mimetype="text/plain",
            )

        # synthesize and vocode: the same speaker embedding is used for every fragment
        embeds = [embed] * len(texts)
        specs = current_synt.synthesize_spectrograms(texts, embeds)
        spec = np.concatenate(specs, axis=1)
        if request.form.get("vocoder") == "WaveRNN":
            vocoder = rnn_vocoder
        else:
            vocoder = gan_vocoder
        wav, sample_rate = vocoder.infer_waveform(spec)

        # Return cooked wav
        out = io.BytesIO()
        write(out, sample_rate, wav.astype(np.float32))
        # Send the buffer's bytes rather than the stream object so the body
        # does not depend on the stream position left behind by write().
        return Response(out.getvalue(), mimetype="audio/wav")

    # NOTE(review): as with synthesize(), no route decorator is visible on
    # this handler in the reviewed source. Confirm its registration.
    def index():
        """Render the ``index.html`` template."""
        return render_template("index.html")

    host = app.config.get("HOST")
    port = app.config.get("PORT")
    web_address = f"http://{host}:{port}"
    print("Web server:" + web_address)
    webbrowser.open(web_address)
    server = wsgi.WSGIServer((host, port), app)
    server.serve_forever()  # blocks until the server is stopped
    return app
if __name__ == "__main__":
    # Start the web front-end when this file is executed as a script.
    webApp()