File size: 1,266 Bytes
a45bb1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import gc
import logging
from typing import Any

import soundfile as sf
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, set_seed


class EndpointHandler:
    """Inference handler that synthesizes speech with a Parler-TTS model.

    The model and tokenizer are loaded once at construction time. Each call
    turns a text ``prompt`` plus a free-text voice ``description`` into an
    audio waveform returned as a NumPy array.
    """

    # Model id used when the caller does not pass one explicitly.
    DEFAULT_MODEL_ID = "parler-tts/parler-tts-mini-expresso"

    def __init__(self, path: str = "", model_id: str = DEFAULT_MODEL_ID):
        """Load the Parler-TTS model and its tokenizer onto the best device.

        Args:
            path: Accepted for compatibility with the handler constructor
                signature the caller uses; it is not used to locate the model.
            model_id: Hugging Face model id (or local directory) passed to
                ``from_pretrained`` for both the model and the tokenizer.
        """
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        self.model = ParlerTTSForConditionalGeneration.from_pretrained(model_id).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    def __call__(self, data: Any):
        """Generate speech for a single request.

        Args:
            data: Request payload shaped like
                ``{"inputs": {"prompt": ..., "description": ...}}`` where
                ``prompt`` is the text to speak and ``description`` describes
                the desired voice/style.

        Returns:
            The generated waveform as a squeezed NumPy array on success, or
            ``{"error": <message>}`` if generation fails.

        Raises:
            KeyError: If ``inputs``, ``prompt`` or ``description`` is missing;
                payload parsing happens before the error handler.
        """
        inputs = data["inputs"]
        prompt = inputs["prompt"]
        description = inputs["description"]

        # The description is fed as ``input_ids`` and the text to speak as
        # ``prompt_input_ids`` -- both tokenized with the same tokenizer.
        input_ids = self.tokenizer(description, return_tensors="pt").input_ids.to(self.device)
        prompt_input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)

        # Fixed seed so repeated identical requests produce identical audio.
        set_seed(42)
        try:
            generation = self.model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
            return generation.cpu().numpy().squeeze()
        except Exception as e:
            # Request boundary: log with traceback, release what this request
            # holds, and report the failure instead of crashing the worker.
            logging.getLogger(__name__).exception("Parler-TTS generation failed: %s", e)
            del input_ids, prompt_input_ids
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return {"error": str(e)}