siddhartharya committed
Commit 8412e92
1 Parent(s): 297b79a

Update utils.py

Files changed (1)
  1. utils.py +4 -17
utils.py CHANGED
@@ -1,12 +1,9 @@
-import os
 from groq import Groq
-from gtts import gTTS
 from pydantic import BaseModel, ValidationError
 from typing import List, Literal
-import tiktoken
+import os
 
-groq_client = Groq()
-tokenizer = tiktoken.get_encoding("cl100k_base")
+groq_client = Groq(api_key=os.environ["GROQ_API_KEY"])
 
 class DialogueItem(BaseModel):
     speaker: Literal["Host", "Guest"]
@@ -15,12 +12,6 @@ class DialogueItem(BaseModel):
 class Dialogue(BaseModel):
     dialogue: List[DialogueItem]
 
-def truncate_text(text, max_tokens=2048):
-    tokens = tokenizer.encode(text)
-    if len(tokens) > max_tokens:
-        return tokenizer.decode(tokens[:max_tokens])
-    return text
-
 def generate_script(system_prompt: str, input_text: str, tone: str):
     input_text = truncate_text(input_text)
     prompt = f"{system_prompt}\nTONE: {tone}\nINPUT TEXT: {input_text}"
@@ -29,7 +20,7 @@ def generate_script(system_prompt: str, input_text: str, tone: str):
         messages=[
             {"role": "system", "content": prompt},
         ],
-        model="llama2-70b-4096",
+        model="llama-3.1-70b-versatile",  # Updated to the correct model name
         max_tokens=2048,
         temperature=0.7
     )
@@ -41,8 +32,4 @@ def generate_script(system_prompt: str, input_text: str, tone: str):
 
     return dialogue
 
-def generate_audio(text: str, speaker: str) -> str:
-    tts = gTTS(text, lang='en', tld='com' if speaker == "Host" else 'co.uk')
-    filename = f"{speaker.lower()}_audio.mp3"
-    tts.save(filename)
-    return filename
+# Make sure the truncate_text function is defined here or imported if it's in another file
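Note that the updated file still calls truncate_text inside generate_script while this commit deletes its definition along with the tiktoken import, so generate_script will raise a NameError until the helper is re-defined or imported, as the commit's own trailing comment acknowledges. A minimal tokenizer-free sketch of a stand-in, assuming a rough character budget (about 4 characters per token) is an acceptable approximation of the old 2048-token cap:

# Hypothetical stand-in for the deleted helper; approximates the old
# tiktoken-based 2048-token cap with a simple character budget.
def truncate_text(text: str, max_tokens: int = 2048, chars_per_token: int = 4) -> str:
    max_chars = max_tokens * chars_per_token  # heuristic: ~4 chars per token
    return text[:max_chars]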
 
 
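Since groq_client is now built at module import from os.environ["GROQ_API_KEY"], importing utils.py raises a KeyError when that variable is unset (the Groq client would also pick up GROQ_API_KEY on its own by default, but the explicit lookup fails earlier and more visibly). A small usage sketch; the key placeholder and the prompt, input, and tone values below are illustrative, not taken from the repo:

import os

# Placeholder key; in practice export GROQ_API_KEY in the shell before running.
os.environ.setdefault("GROQ_API_KEY", "gsk_your_key_here")

from utils import generate_script  # assumes utils.py is on the import path

script = generate_script(
    system_prompt="You are writing a short two-person podcast script.",
    input_text="Some source text to turn into a dialogue.",
    tone="casual",
)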