siddhartharya committed on
Commit
ba22d1b
1 Parent(s): 8412e92

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +16 -2
utils.py CHANGED
@@ -2,8 +2,12 @@ from groq import Groq
2
  from pydantic import BaseModel, ValidationError
3
  from typing import List, Literal
4
  import os
 
 
 
5
 
6
  groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
 
7
 
8
  class DialogueItem(BaseModel):
9
  speaker: Literal["Host", "Guest"]
@@ -12,6 +16,12 @@ class DialogueItem(BaseModel):
12
  class Dialogue(BaseModel):
13
  dialogue: List[DialogueItem]
14
 
 
 
 
 
 
 
15
  def generate_script(system_prompt: str, input_text: str, tone: str):
16
  input_text = truncate_text(input_text)
17
  prompt = f"{system_prompt}\nTONE: {tone}\nINPUT TEXT: {input_text}"
@@ -20,7 +30,7 @@ def generate_script(system_prompt: str, input_text: str, tone: str):
20
  messages=[
21
  {"role": "system", "content": prompt},
22
  ],
23
- model="llama-3.1-70b-versatile", # Updated to the correct model name
24
  max_tokens=2048,
25
  temperature=0.7
26
  )
@@ -32,4 +42,8 @@ def generate_script(system_prompt: str, input_text: str, tone: str):
32
 
33
  return dialogue
34
 
35
- # Make sure the truncate_text function is defined here or imported if it's in another file
 
 
 
 
 
2
  from pydantic import BaseModel, ValidationError
3
  from typing import List, Literal
4
  import os
5
+ import tiktoken
6
+ from gtts import gTTS
7
+ import tempfile
8
 
9
  groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
10
+ tokenizer = tiktoken.get_encoding("cl100k_base")
11
 
12
  class DialogueItem(BaseModel):
13
  speaker: Literal["Host", "Guest"]
 
16
class Dialogue(BaseModel):
    """A complete dialogue: an ordered list of Host/Guest turns."""
    # Each item carries a speaker label and its line of dialogue.
    dialogue: List[DialogueItem]
18
 
19
def truncate_text(text, max_tokens=2048):
    """Clip *text* so it encodes to at most *max_tokens* tokens.

    Uses the module-level cl100k_base tokenizer. Text that already fits
    is returned unchanged; otherwise the first *max_tokens* tokens are
    decoded back to a (possibly shorter) string.
    """
    encoded = tokenizer.encode(text)
    if len(encoded) <= max_tokens:
        return text
    return tokenizer.decode(encoded[:max_tokens])
24
+
25
  def generate_script(system_prompt: str, input_text: str, tone: str):
26
  input_text = truncate_text(input_text)
27
  prompt = f"{system_prompt}\nTONE: {tone}\nINPUT TEXT: {input_text}"
 
30
  messages=[
31
  {"role": "system", "content": prompt},
32
  ],
33
+ model="llama-3.1-70b-versatile",
34
  max_tokens=2048,
35
  temperature=0.7
36
  )
 
42
 
43
  return dialogue
44
 
45
def generate_audio(text: str, speaker: str) -> str:
    """Synthesize *text* to speech with gTTS and return the MP3 file path.

    The "Host" speaker is voiced with the 'com' TLD accent; any other
    speaker falls back to 'co.uk'. The audio is written to a temp file
    that is NOT auto-deleted — the caller owns cleanup of the returned path.
    """
    if speaker == "Host":
        accent = 'com'
    else:
        accent = 'co.uk'
    speech = gTTS(text, lang='en', tld=accent)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as audio_file:
        speech.save(audio_file.name)
    return audio_file.name