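"""OpenedAI Speech API server.

Serves an OpenAI API compatible /v1/audio/speech text-to-speech endpoint,
using piper for the 'tts-1' model and coqui XTTS for 'tts-1-hd'.
"""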
import argparse
import contextlib
import gc
import json
import os
import queue
import re
import subprocess
import sys
import threading
import time

import uvicorn
import yaml
from fastapi.responses import StreamingResponse
from loguru import logger
from pydantic import BaseModel

from openedai import OpenAIStub, BadRequestError, ServiceUnavailableError

@contextlib.asynccontextmanager
async def lifespan(app):
    yield
    # On shutdown, release any GPU memory still held by the process.
    gc.collect()
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    except Exception:
        pass

app = OpenAIStub(lifespan=lifespan)
xtts = None  # the currently loaded xtts_wrapper, if any
args = None  # parsed command line arguments (set in __main__)

def unload_model():
    import torch
    global xtts
    if xtts:
        logger.info("Unloading model")
        xtts.xtts.to('cpu')  # move weights off the GPU before dropping the reference
        xtts = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

class xtts_wrapper:
    # How often (in seconds) the idle timer re-checks for inactivity.
    check_interval: int = 1

    def __init__(self, model_name, device, model_path=None, unload_timer=None):
        self.model_name = model_name
        self.unload_timer = unload_timer
        self.last_used = time.time()
        self.timer = None
        self.lock = threading.Lock()

        logger.info(f"Loading model {self.model_name} to {device}")

        if model_path is None:
            # Download the named model (or reuse a cached copy).
            model_path = ModelManager().download_model(model_name)[0]

        config_path = os.path.join(model_path, 'config.json')
        config = XttsConfig()
        config.load_json(config_path)
        self.xtts = Xtts.init_from_config(config)
        self.xtts.load_checkpoint(config, checkpoint_dir=model_path, use_deepspeed=args.use_deepspeed)
        self.xtts = self.xtts.to(device=device)
        self.xtts.eval()

        if self.unload_timer:
            logger.info(f"Setting unload timer to {self.unload_timer} seconds")
            self.last_used = time.time()
            self.check_idle()

    def check_idle(self):
        with self.lock:
            if time.time() - self.last_used >= self.unload_timer:
                logger.info("Unloading TTS model due to inactivity")
                unload_model()
            else:
                # Not idle long enough yet; check again in check_interval seconds.
                self.timer = threading.Timer(self.check_interval, self.check_idle)
                self.timer.daemon = True
                self.timer.start()

    def tts(self, text, language, audio_path, **hf_generate_kwargs):
        with torch.no_grad():
            self.last_used = time.time()
            start = time.time()
            tokens = 0
            try:
                with self.lock:
                    logger.debug(f"generating [{language}]: {[text]}")

                    # Build the speaker conditioning latents from the sample(s),
                    # then stream PCM chunks as they are generated.
                    gpt_cond_latent, speaker_embedding = self.xtts.get_conditioning_latents(audio_path=audio_path)
                    pcm_stream = self.xtts.inference_stream(text, language, gpt_cond_latent, speaker_embedding, **hf_generate_kwargs)
                    self.last_used = time.time()

                while True:
                    with self.lock:
                        yield next(pcm_stream).cpu().numpy().tobytes()
                        self.last_used = time.time()
                        tokens += 1

            except StopIteration:
                # inference_stream is exhausted
                pass

            finally:
                elapsed = time.time() - start
                logger.debug(f"Generated {tokens} tokens in {elapsed:.2f}s @ {tokens / elapsed:.2f} T/s")
                self.last_used = time.time()

def default_exists(filename: str):
    # If a config file is missing, seed it from its bundled *.default sibling.
    if not os.path.exists(filename):
        fpath, ext = os.path.splitext(filename)
        basename = os.path.basename(fpath)
        default = f"{basename}.default{ext}"

        logger.info(f"{filename} does not exist, setting defaults from {default}")

        with open(default, 'r', encoding='utf8') as from_file:
            with open(filename, 'w', encoding='utf8') as to_file:
                to_file.write(from_file.read())

def preprocess(raw_input):
    # Apply the user-configurable regex substitutions, in order, then trim.
    default_exists('config/pre_process_map.yaml')
    with open('config/pre_process_map.yaml', 'r', encoding='utf8') as file:
        pre_process_map = yaml.safe_load(file)

    for a, b in pre_process_map:
        raw_input = re.sub(a, b, raw_input)

    raw_input = raw_input.strip()

    return raw_input

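# config/pre_process_map.yaml is a list of [regex, replacement] pairs applied
# in order before synthesis. Illustrative entries (not the shipped defaults):
#
#   - ['&', ' and ']
#   - ['%', ' percent ']
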
def map_voice_to_speaker(voice: str, model: str):
    # Look up the per-model settings for an OpenAI voice name.
    default_exists('config/voice_to_speaker.yaml')
    with open('config/voice_to_speaker.yaml', 'r', encoding='utf8') as file:
        voice_map = yaml.safe_load(file)

    try:
        return voice_map[model][voice]

    except KeyError as e:
        raise BadRequestError(f"Error loading voice: {voice}, KeyError: {e}", param='voice')

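# config/voice_to_speaker.yaml maps model name -> voice name -> settings.
# An illustrative shape (example values, not the shipped defaults):
#
#   tts-1:
#     alloy:
#       model: voices/en_US-libritts_r-medium.onnx   # piper voice
#       speaker: 79                                  # optional piper speaker id
#   tts-1-hd:
#     alloy:
#       model: xtts
#       speaker: voices/alloy.wav   # speaker sample file, or a directory of samples
#       language: auto              # optional; also: speed, model_path, comment, ...
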
class GenerateSpeechRequest(BaseModel):
    # Request body for /v1/audio/speech, mirroring the OpenAI API schema.
    model: str = "tts-1"  # 'tts-1' (piper) or 'tts-1-hd' (xtts)
    input: str
    voice: str = "alloy"
    response_format: str = "mp3"
    speed: float = 1.0

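# A sample request, assuming the server is listening on localhost:8000:
#
#   curl -s http://localhost:8000/v1/audio/speech -H "Content-Type: application/json" \
#     -d '{"model": "tts-1", "input": "Hello world!", "voice": "alloy"}' > speech.mp3
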
def build_ffmpeg_args(response_format, input_format, sample_rate):
    # Build an ffmpeg command that reads raw audio from stdin and encodes it
    # into the requested response format.
    if input_format == 'WAV':
        ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", "WAV", "-i", "-"]
    else:
        ffmpeg_args = ["ffmpeg", "-loglevel", "error", "-f", input_format, "-ar", sample_rate, "-ac", "1", "-i", "-"]

    if response_format == "mp3":
        ffmpeg_args.extend(["-f", "mp3", "-c:a", "libmp3lame", "-ab", "64k"])
    elif response_format == "opus":
        ffmpeg_args.extend(["-f", "ogg", "-c:a", "libopus"])
    elif response_format == "aac":
        ffmpeg_args.extend(["-f", "adts", "-c:a", "aac", "-ab", "64k"])
    elif response_format == "flac":
        ffmpeg_args.extend(["-f", "flac", "-c:a", "flac"])
    elif response_format == "wav":
        ffmpeg_args.extend(["-f", "wav", "-c:a", "pcm_s16le"])
    elif response_format == "pcm":
        ffmpeg_args.extend(["-f", "s16le", "-c:a", "pcm_s16le"])

    return ffmpeg_args

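# For example, build_ffmpeg_args("mp3", input_format="s16le", sample_rate="22050")
# returns the argument list for:
#
#   ffmpeg -loglevel error -f s16le -ar 22050 -ac 1 -i - -f mp3 -c:a libmp3lame -ab 64k
#
# The caller appends the trailing '-' so ffmpeg writes to stdout.
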
@app.post("/v1/audio/speech", response_class=StreamingResponse)
async def generate_speech(request: GenerateSpeechRequest):
    global xtts, args
    if len(request.input) < 1:
        raise BadRequestError("Empty Input", param='input')

    input_text = preprocess(request.input)

    if len(input_text) < 1:
        raise BadRequestError("Input text empty after preprocess.", param='input')

    model = request.model
    voice = request.voice
    response_format = request.response_format.lower()
    speed = request.speed

    # Map the response format to its HTTP media type.
    if response_format == "mp3":
        media_type = "audio/mpeg"
    elif response_format == "opus":
        media_type = "audio/ogg;codec=opus"
    elif response_format == "aac":
        media_type = "audio/aac"
    elif response_format == "flac":
        media_type = "audio/x-flac"
    elif response_format == "wav":
        media_type = "audio/wav"
    elif response_format == "pcm":
        if model == 'tts-1':
            media_type = "audio/pcm;rate=22050"
        elif model == 'tts-1-hd':
            media_type = "audio/pcm;rate=24000"
    else:
        raise BadRequestError(f"Invalid response_format: '{response_format}'", param='response_format')

    ffmpeg_args = None

    # tts-1: synthesize with piper (also used for all models when xtts is disabled).
    if model == 'tts-1' or args.xtts_device == 'none':
        voice_map = map_voice_to_speaker(voice, 'tts-1')
        try:
            piper_model = voice_map['model']

        except KeyError as e:
            raise ServiceUnavailableError(f"Configuration error: tts-1 voice '{voice}' is missing 'model:' setting. KeyError: {e}")

        speaker = voice_map.get('speaker', None)

        tts_args = ["piper", "--model", str(piper_model), "--data-dir", "voices", "--download-dir", "voices", "--output-raw"]
        if speaker:
            tts_args.extend(["--speaker", str(speaker)])
        if speed != 1.0:
            # piper expresses speed as a length scale (the inverse of speed).
            tts_args.extend(["--length-scale", f"{1.0/speed}"])

        tts_proc = subprocess.Popen(tts_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        tts_proc.stdin.write(input_text.encode('utf-8'))
        tts_proc.stdin.close()

        # Read the output sample rate from the piper voice config, if present.
        try:
            with open(f"{piper_model}.json", 'r') as pvc_f:
                conf = json.load(pvc_f)
                sample_rate = str(conf['audio']['sample_rate'])

        except Exception:
            sample_rate = '22050'

        ffmpeg_args = build_ffmpeg_args(response_format, input_format="s16le", sample_rate=sample_rate)

        ffmpeg_args.extend(["-"])
        ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=tts_proc.stdout, stdout=subprocess.PIPE)

        return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type)

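    # Conceptually, the tts-1 path above is equivalent to a shell pipeline like
    # this (an illustration with a placeholder voice, not code that runs here):
    #
    #   echo 'Hello world!' | piper --model voices/<voice>.onnx --output-raw \
    #     | ffmpeg -loglevel error -f s16le -ar 22050 -ac 1 -i - -f mp3 -c:a libmp3lame -ab 64k -
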
    # tts-1-hd: synthesize with the xtts model.
    elif model == 'tts-1-hd':
        voice_map = map_voice_to_speaker(voice, 'tts-1-hd')
        try:
            tts_model = voice_map.pop('model')
            speaker = voice_map.pop('speaker')

        except KeyError as e:
            raise ServiceUnavailableError(f"Configuration error: tts-1-hd voice '{voice}' is missing setting. KeyError: {e}")

        # Swap models if a different xtts model is requested.
        if xtts and xtts.model_name != tts_model:
            unload_model()

        tts_model_path = voice_map.pop('model_path', None)

        if xtts is None:
            xtts = xtts_wrapper(tts_model, device=args.xtts_device, model_path=tts_model_path, unload_timer=args.unload_timer)

        ffmpeg_args = build_ffmpeg_args(response_format, input_format="f32le", sample_rate="24000")

        # Split extreme speeds between xtts and an ffmpeg atempo filter: halve
        # sub-0.5 speeds via atempo=0.5, and apply speeds above 1.0 entirely in ffmpeg.
        speed = voice_map.pop('speed', speed)
        if speed < 0.5:
            speed = speed / 0.5
            ffmpeg_args.extend(["-af", "atempo=0.5"])
        if speed > 1.0:
            ffmpeg_args.extend(["-af", f"atempo={speed}"])
            speed = 1.0

        ffmpeg_args.extend(["-"])

        language = voice_map.pop('language', 'auto')
        if language == 'auto':
            try:
                language = detect(input_text)
                if language not in [
                    'en', 'es', 'fr', 'de', 'it', 'pt', 'pl', 'tr',
                    'ru', 'nl', 'cs', 'ar', 'zh-cn', 'hu', 'ko', 'ja', 'hi'
                ]:
                    logger.debug(f"Detected language {language} not supported, defaulting to en")
                    language = 'en'
                else:
                    logger.debug(f"Detected language: {language}")
            except Exception:
                language = 'en'
                logger.debug("Failed to detect language, defaulting to en")

        comment = voice_map.pop('comment', None)  # popped so it doesn't leak into the generate kwargs

        # Any remaining voice_map settings are passed through to xtts inference.
        hf_generate_kwargs = dict(
            speed=speed,
            **voice_map,
        )

        hf_generate_kwargs['enable_text_splitting'] = hf_generate_kwargs.get('enable_text_splitting', True)

        if hf_generate_kwargs['enable_text_splitting']:
            if language == 'zh-cn':
                split_lang = 'zh'
            else:
                split_lang = language
            all_text = split_sentence(input_text, split_lang, xtts.xtts.tokenizer.char_limits[split_lang])
        else:
            all_text = [input_text]

        ffmpeg_proc = subprocess.Popen(ffmpeg_args, stdin=subprocess.PIPE, stdout=subprocess.PIPE)

        in_q = queue.Queue()  # xtts PCM chunks -> ffmpeg writer
        ex_q = queue.Queue()  # exceptions raised in the writer thread

        def get_speaker_samples(samples: str) -> list[str]:
            # Accept either a single speaker wav file or a directory of samples.
            if os.path.isfile(samples):
                audio_path = [samples]
            elif os.path.isdir(samples):
                audio_path = [os.path.join(samples, sample) for sample in os.listdir(samples) if os.path.isfile(os.path.join(samples, sample))]

                if len(audio_path) < 1:
                    logger.error(f"No files found: {samples}")
                    raise ServiceUnavailableError(f"Invalid path: {samples}")
            else:
                logger.error(f"Invalid path: {samples}")
                raise ServiceUnavailableError(f"Invalid path: {samples}")

            return audio_path

        def exception_check(exq: queue.Queue):
            # Re-raise any exception the writer thread reported.
            try:
                e = exq.get_nowait()
            except queue.Empty:
                return

            raise e

        def generator():
            # Producer: run xtts inference and push PCM chunks onto in_q.
            audio_path = get_speaker_samples(speaker)
            logger.debug(f"{voice} wav samples: {audio_path}")

            try:
                for text in all_text:
                    for chunk in xtts.tts(text=text, language=language, audio_path=audio_path, **hf_generate_kwargs):
                        exception_check(ex_q)
                        in_q.put(chunk)

            except BrokenPipeError:
                logger.info("Client disconnected - 'Broken pipe'")

            except Exception as e:
                logger.error(f"Exception: {repr(e)}")
                raise e

            finally:
                in_q.put(None)  # sentinel: no more chunks

        def out_writer():
            # Consumer: feed PCM chunks from in_q into ffmpeg's stdin.
            try:
                while True:
                    chunk = in_q.get()
                    if chunk is None:
                        break
                    ffmpeg_proc.stdin.write(chunk)

            except Exception as e:
                # Surface the error to the producer and stop the pipeline.
                ex_q.put(e)
                ffmpeg_proc.kill()
                return

            finally:
                ffmpeg_proc.stdin.close()

        generator_worker = threading.Thread(target=generator, daemon=True)
        generator_worker.start()

        out_writer_worker = threading.Thread(target=out_writer, daemon=True)
        out_writer_worker.start()

        def cleanup():
            # Runs after the response finishes; the daemon worker threads exit
            # on their own once ffmpeg is gone.
            ffmpeg_proc.kill()

        return StreamingResponse(content=ffmpeg_proc.stdout, media_type=media_type, background=cleanup)
    else:
        raise BadRequestError("No such model, must be tts-1 or tts-1-hd.", param='model')

def auto_torch_device():
    # Pick the best available torch device, or 'none' if torch is unavailable.
    try:
        import torch
        if torch.cuda.is_available():
            return 'cuda'
        if torch.backends.mps.is_available() and torch.backends.mps.is_built():
            return 'mps'
        return 'cpu'

    except Exception:
        return 'none'

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='OpenedAI Speech API Server',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--xtts_device', action='store', default=auto_torch_device(), help="Set the device for the xtts model. The special value of 'none' will use piper for all models.")
    parser.add_argument('--preload', action='store', default=None, help="Preload a model (Ex. 'xtts' or 'xtts_v2.0.2'). By default it's loaded on first use.")
    parser.add_argument('--unload-timer', action='store', default=None, type=int, help="Idle unload timer for the XTTS model in seconds, Ex. 900 for 15 minutes")
    parser.add_argument('--use-deepspeed', action='store_true', default=False, help="Use deepspeed with xtts (this option is unsupported)")
    parser.add_argument('--no-cache-speaker', action='store_true', default=False, help="Don't use the speaker wav embeddings cache")
    parser.add_argument('-P', '--port', action='store', default=8000, type=int, help="Server tcp port")
    parser.add_argument('-H', '--host', action='store', default='0.0.0.0', help="Host to listen on, Ex. 0.0.0.0")
    parser.add_argument('-L', '--log-level', default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the log level")

    args = parser.parse_args()

    default_exists('config/pre_process_map.yaml')
    default_exists('config/voice_to_speaker.yaml')

    logger.remove()
    logger.add(sink=sys.stderr, level=args.log_level)

    if args.xtts_device != "none":
        # The xtts dependencies are only imported when the xtts path is enabled.
        import torch
        from TTS.tts.configs.xtts_config import XttsConfig
        from TTS.tts.models.xtts import Xtts
        from TTS.utils.manage import ModelManager
        from TTS.tts.layers.xtts.tokenizer import split_sentence
        from langdetect import detect

    if args.preload:
        xtts = xtts_wrapper(args.preload, device=args.xtts_device, unload_timer=args.unload_timer)

    app.register_model('tts-1')
    app.register_model('tts-1-hd')

    uvicorn.run(app, host=args.host, port=args.port)
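# Example startup, assuming this file is saved as speech.py (all flags optional):
#
#   python speech.py --xtts_device cuda --preload xtts --unload-timer 900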