# Source snapshot at revision a45bb1c (1,266 bytes as originally published).
import gc
import logging
from typing import Any

import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, set_seed
class EndpointHandler:
    """Text-to-speech request handler backed by a Parler-TTS model.

    The model and tokenizer are loaded once at construction. Each call
    tokenizes the request's ``description`` and ``prompt`` and runs
    ``model.generate`` on them. The interface (``__init__(path)`` /
    ``__call__(data)`` with ``data["inputs"]``) appears to follow the Hugging
    Face Inference Endpoints custom-handler convention; confirm against the
    deployment.
    """

    # Checkpoint this handler has always served; override via ``model_id``.
    DEFAULT_MODEL_ID = "parler-tts/parler-tts-mini-expresso"
    # Reseeded before every generation so repeated requests sample the same way.
    SEED = 42

    _logger = logging.getLogger(__name__)

    def __init__(self, path: str = "", model_id: str = DEFAULT_MODEL_ID):
        """Load the model and tokenizer onto the best available device.

        Args:
            path: Repository path handed in by the serving runtime. Accepted
                for interface compatibility but not used; weights are loaded
                from ``model_id``.
            model_id: Hub id (or local directory) passed to ``from_pretrained``
                for both the model and the tokenizer.
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.model = ParlerTTSForConditionalGeneration.from_pretrained(
            model_id
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    @staticmethod
    def _extract_inputs(data: Any):
        """Return ``(prompt, description)`` from a request payload.

        Raises:
            KeyError: If ``"inputs"``, ``"prompt"`` or ``"description"`` is
                missing.
        """
        inputs = data["inputs"]
        return inputs["prompt"], inputs["description"]

    def __call__(self, data: Any):
        """Generate audio for one request.

        Args:
            data: Payload shaped like
                ``{"inputs": {"prompt": ..., "description": ...}}``.
                ``description`` is tokenized into ``input_ids`` and ``prompt``
                into ``prompt_input_ids`` for ``model.generate``.

        Returns:
            On success, the generated output as a NumPy array with singleton
            dimensions squeezed out. If generation or the conversion to NumPy
            fails, a dict ``{"error": <exception message>}``.

        Raises:
            KeyError: If the payload lacks ``inputs``, ``prompt`` or
                ``description``. Parsing happens before the guarded section.
        """
        prompt, description = self._extract_inputs(data)
        input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
        prompt_input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)

        set_seed(self.SEED)
        try:
            generation = self.model.generate(
                input_ids=input_ids, prompt_input_ids=prompt_input_ids
            )
            audio_arr = generation.cpu().numpy().squeeze()
        except Exception as exc:  # request boundary: report instead of crashing
            error = str(exc)
            self._logger.exception("Parler-TTS generation failed: %s", error)
        else:
            return audio_arr

        # Cleanup runs after the except block so the exception/traceback no
        # longer pins this frame. The request tensors are dropped explicitly so
        # the collector and CUDA cache flush can actually reclaim their memory.
        del input_ids, prompt_input_ids
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return {"error": error}