from groq import Groq
from pydantic import BaseModel, ValidationError
from typing import List, Literal
import os
import tiktoken
import tempfile
import json
import re
from transformers import pipeline
from datasets import load_dataset
import torch
import soundfile as sf

groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
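# cl100k_base is an OpenAI encoding used here only as a rough proxy for the
# Llama tokenizer served by Groq, so the token limits below are approximate.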
tokenizer = tiktoken.get_encoding("cl100k_base")

# A single SpeechT5 pipeline serves both hosts; the voice is chosen per call
# via the speaker embedding, so there is no need to load the model twice.
tts = pipeline("text-to-speech", model="microsoft/speecht5_tts", device="cpu")

# Load speaker x-vector embeddings. torch.load cannot read from a URL, so the
# embeddings come from the CMU ARCTIC x-vector dataset instead; index 7306 is
# the female voice used in the Hugging Face SpeechT5 examples, and index 0 is
# assumed here for a male voice (both indices can be tuned to taste).
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
male_embedding = torch.tensor(embeddings_dataset[0]["xvector"]).unsqueeze(0)
female_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

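# The LLM is expected (via the caller-supplied system prompt) to return JSON
# matching this schema, e.g.
#   {"dialogue": [{"speaker": "John", "text": "Welcome back to the show."},
#                 {"speaker": "Sarah", "text": "Great to be here, John."}]}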
class DialogueItem(BaseModel):
    speaker: Literal["John", "Sarah"]
    text: str

class Dialogue(BaseModel):
    dialogue: List[DialogueItem]

def truncate_text(text, max_tokens=2048):
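    """Trim text to roughly max_tokens tokens using the cl100k_base encoding."""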
    tokens = tokenizer.encode(text)
    if len(tokens) > max_tokens:
        return tokenizer.decode(tokens[:max_tokens])
    return text

def generate_script(system_prompt: str, input_text: str, tone: str):
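    """Ask the Groq LLM for a two-host dialogue and validate it as a Dialogue.

    The model sometimes wraps its JSON in markdown fences or surrounding prose,
    so the reply is cleaned first and, if direct parsing fails, the first JSON
    object found in the text is tried as a fallback.
    """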
    input_text = truncate_text(input_text)
    prompt = f"{system_prompt}\nTONE: {tone}\nINPUT TEXT: {input_text}"
    
    response = groq_client.chat.completions.create(
        messages=[
            {"role": "system", "content": prompt},
        ],
        model="llama-3.1-70b-versatile",
        max_tokens=2048,
        temperature=0.7
    )
    
    content = response.choices[0].message.content
    content = re.sub(r'```json\s*|\s*```', '', content)
    
    try:
        json_data = json.loads(content)
        dialogue = Dialogue.model_validate(json_data)
    except json.JSONDecodeError:
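        # Fallback: the reply may contain prose around the JSON; try parsing
        # the first {...} block found in the raw content.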
        match = re.search(r'\{.*\}', content, re.DOTALL)
        if match:
            try:
                json_data = json.loads(match.group())
                dialogue = Dialogue.model_validate(json_data)
            except (json.JSONDecodeError, ValidationError) as e:
                raise ValueError(f"Failed to parse dialogue JSON: {e}\nContent: {content}")
        else:
            raise ValueError(f"Failed to find valid JSON in the response: {content}")
    except ValidationError as e:
        raise ValueError(f"Failed to validate dialogue structure: {e}\nContent: {content}")
    
    return dialogue

def generate_audio(text: str, speaker: str) -> str:
    # "John" is voiced with the male embedding, "Sarah" with the female one.
    embedding = male_embedding if speaker == "John" else female_embedding
    speech = tts(text, forward_params={"speaker_embeddings": embedding})

    # Write the waveform to a temporary WAV file and return its path; the
    # caller is responsible for deleting the file when done.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        sf.write(temp_audio.name, speech["audio"], speech["sampling_rate"])
        return temp_audio.name