"""Generate a two-speaker dialogue script with a Groq-hosted LLM and synthesize speech for it."""

import json
import os
import re
import tempfile
from typing import List, Literal

import soundfile as sf
import tiktoken
import torch
from groq import Groq
from pydantic import BaseModel, ValidationError
from transformers import pipeline

groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
tokenizer = tiktoken.get_encoding("cl100k_base")

# Initialize TTS pipelines
tts_male = pipeline("text-to-speech", model="microsoft/speecht5_tts", device="cpu")
tts_female = pipeline("text-to-speech", model="microsoft/speecht5_tts", device="cpu")

# Strips Markdown code fences that the model may wrap around its JSON answer.
_CODE_FENCE_RE = re.compile(r'```json\s*|\s*```')
# Greedy match from the first "{" to the last "}" (fallback JSON extraction).
_JSON_OBJECT_RE = re.compile(r'\{.*\}', re.DOTALL)


def _load_speaker_embedding(url: str):
    """Download (and cache) a serialized speaker embedding and load it on CPU.

    ``torch.load`` only accepts a local path or a file-like object, so a
    remote file has to be fetched first; ``torch.hub`` does the download and
    caching and then deserializes the file.
    """
    return torch.hub.load_state_dict_from_url(url, map_location="cpu")


# Load speaker embeddings
# NOTE(review): confirm that these files actually exist at this location and
# that they hold a tensor in the shape SpeechT5 expects for
# ``speaker_embeddings`` -- neither can be verified from this file.
male_embedding = _load_speaker_embedding(
    "https://huggingface.co/microsoft/speecht5_tts/resolve/main/en_speaker_1.pt"
)
female_embedding = _load_speaker_embedding(
    "https://huggingface.co/microsoft/speecht5_tts/resolve/main/en_speaker_9.pt"
)


class DialogueItem(BaseModel):
    """A single line of dialogue spoken by one of the two allowed speakers."""

    speaker: Literal["John", "Sarah"]
    text: str


class Dialogue(BaseModel):
    """The full script: an ordered list of dialogue lines."""

    dialogue: List[DialogueItem]


def truncate_text(text: str, max_tokens: int = 2048) -> str:
    """Return *text* cut down to at most *max_tokens* ``cl100k_base`` tokens.

    The text is returned unchanged when it already fits.
    """
    tokens = tokenizer.encode(text)
    if len(tokens) > max_tokens:
        return tokenizer.decode(tokens[:max_tokens])
    return text


def _parse_dialogue(content: str) -> Dialogue:
    """Parse and validate the model's answer as a ``Dialogue``.

    The whole string is tried as JSON first. If that fails, the span from the
    first ``{`` to the last ``}`` is tried instead, which tolerates extra text
    around the JSON object.

    Raises:
        ValueError: if no JSON can be decoded or it does not match the
            ``Dialogue`` schema. The original error is chained as the cause.
    """
    try:
        json_data = json.loads(content)
    except json.JSONDecodeError as decode_error:
        match = _JSON_OBJECT_RE.search(content)
        if match is None:
            raise ValueError(
                f"Failed to find valid JSON in the response: {content}"
            ) from decode_error
        try:
            return Dialogue.model_validate(json.loads(match.group()))
        except (json.JSONDecodeError, ValidationError) as e:
            raise ValueError(
                f"Failed to parse dialogue JSON: {e}\nContent: {content}"
            ) from e

    try:
        return Dialogue.model_validate(json_data)
    except ValidationError as e:
        raise ValueError(
            f"Failed to validate dialogue structure: {e}\nContent: {content}"
        ) from e


def generate_script(system_prompt: str, input_text: str, tone: str) -> Dialogue:
    """Ask the LLM for a dialogue script and return it as a validated ``Dialogue``.

    Args:
        system_prompt: Instructions placed at the start of the prompt.
        input_text: Source material; truncated with ``truncate_text`` first.
        tone: Tone label inserted into the prompt.

    Returns:
        The parsed and validated dialogue.

    Raises:
        ValueError: if the response cannot be parsed as JSON or does not
            match the ``Dialogue`` schema.
    """
    input_text = truncate_text(input_text)
    prompt = f"{system_prompt}\nTONE: {tone}\nINPUT TEXT: {input_text}"

    response = groq_client.chat.completions.create(
        messages=[
            {"role": "system", "content": prompt},
        ],
        model="llama-3.1-70b-versatile",
        max_tokens=2048,
        temperature=0.7,
    )

    # A missing message body is treated as an empty answer so that it surfaces
    # as the same ValueError as any other unparsable response.
    content = response.choices[0].message.content or ""
    content = _CODE_FENCE_RE.sub('', content)
    return _parse_dialogue(content)


def generate_audio(text: str, speaker: str) -> str:
    """Synthesize *text* and return the path of the resulting WAV file.

    Args:
        text: The sentence(s) to speak.
        speaker: ``"John"`` selects the male pipeline and embedding; any other
            value selects the female ones (Sarah).

    Returns:
        Path of a temporary ``.wav`` file. The file is not deleted
        automatically; the caller is responsible for removing it.
    """
    if speaker == "John":
        synthesizer, embedding = tts_male, male_embedding
    else:  # Sarah
        synthesizer, embedding = tts_female, female_embedding

    # Model arguments have to be routed through ``forward_params``; the
    # pipeline rejects unknown top-level keyword arguments.
    speech = synthesizer(text, forward_params={"speaker_embeddings": embedding})

    # Reserve a unique file name, then close the handle before writing by
    # name: a file that is still open cannot be reopened on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
        audio_path = temp_audio.name
    sf.write(audio_path, speech["audio"], speech["sampling_rate"])
    return audio_path