saq1b committed on
Commit
1491507
1 Parent(s): b4a62f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -19
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from pydub import AudioSegment
3
- from groq import AsyncGroq
 
4
  import json
5
  import uuid
6
  import io
@@ -12,10 +13,10 @@ import os
12
  from typing import List, Dict, Tuple
13
 
14
  class PodcastGenerator:
15
- def __init__(self, groq_api_key: str):
16
- self.groq_client = AsyncGroq(api_key=groq_api_key)
17
 
18
- async def generate_script(self, prompt: str, language: str) -> Dict:
19
  example = """
20
  {
21
  "topic": "AGI",
@@ -228,19 +229,32 @@ Follow this example structure:
228
  user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
229
 
230
  messages = [
231
- {"role": "system", "content": system_prompt},
232
- {"role": "user", "content": user_prompt}
233
  ]
234
 
235
- response = await self.groq_client.chat.completions.create(
236
- messages=messages,
237
- model="llama-3.1-70b-versatile",
238
- response_format={"type": "json_object"},
239
- max_tokens=4096,
240
- temperature=1,
 
 
 
 
 
 
 
 
 
 
 
 
241
  )
242
 
243
- return json.loads(response.choices[0].message.content)
 
 
244
 
245
  async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
246
  voice = speaker1 if speaker == 1 else speaker2
@@ -265,9 +279,11 @@ Follow this example structure:
265
  combined_audio.export(output_filename, format="wav")
266
  return output_filename
267
 
268
- async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str) -> str:
269
- podcast_json = await self.generate_script(input_text, language)
 
270
  print(f"Generated podcast script:\n{podcast_json}")
 
271
  audio_files = await asyncio.gather(*[self.tts_generate(item['line'], item['speaker'], speaker1, speaker2) for item in podcast_json['podcast']])
272
  combined_audio = await self.combine_audio_files(audio_files)
273
  return combined_audio
@@ -293,9 +309,10 @@ class TextExtractor:
293
  elif file_extension.lower() == '.txt':
294
  return await cls.extract_from_txt(file_path)
295
  else:
 
296
  raise ValueError(f"Unsupported file type: {file_extension}")
297
 
298
- async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str) -> str:
299
  voice_names = {
300
  "Andrew - English (United States)": "en-US-AndrewMultilingualNeural",
301
  "Ava - English (United States)": "en-US-AvaMultilingualNeural",
@@ -313,8 +330,11 @@ async def process_input(input_text: str, input_file, language: str, speaker1: st
313
  if input_file:
314
  input_text = await TextExtractor.extract_text(input_file.name)
315
 
316
- podcast_generator = PodcastGenerator(groq_api_key=os.environ["GROQ_API_KEY"])
317
- return await podcast_generator.generate_podcast(input_text, language, speaker1, speaker2)
 
 
 
318
 
319
  # Define Gradio interface
320
  iface = gr.Interface(
@@ -360,7 +380,8 @@ iface = gr.Interface(
360
  "Remy - French (France)",
361
  "Vivienne - French (France)"
362
  ],
363
- value="Ava - English (United States)")
 
364
  ],
365
  outputs=[
366
  gr.Audio(label="Generated Podcast Audio")
 
1
  import gradio as gr
2
  from pydub import AudioSegment
3
+ import google.generativeai as genai
4
+ from google.generativeai.types import HarmCategory, HarmBlockThreshold
5
  import json
6
  import uuid
7
  import io
 
13
  from typing import List, Dict, Tuple
14
 
15
  class PodcastGenerator:
16
+ def __init__(self):
17
+ pass
18
 
19
+ async def generate_script(self, prompt: str, language: str, api_key: str) -> Dict:
20
  example = """
21
  {
22
  "topic": "AGI",
 
229
  user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
230
 
231
  messages = [
232
+ {"role": "user", "parts": [user_prompt]}
 
233
  ]
234
 
235
+ genai.configure(api_key=api_key)
236
+
237
+ generation_config = {
238
+ "temperature": 1,
239
+ "max_output_tokens": 8192,
240
+ "response_mime_type": "application/json",
241
+ }
242
+
243
+ model = genai.GenerativeModel(
244
+ model_name="gemini-1.5-flash",
245
+ generation_config=generation_config,
246
+ safety_settings={
247
+ HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
248
+ HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
249
+ HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
250
+ HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE
251
+ },
252
+ system_instruction=system_prompt
253
  )
254
 
255
+ response = await model.generate_content_async(messages)
256
+
257
+ return json.loads(response.text)
258
 
259
  async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
260
  voice = speaker1 if speaker == 1 else speaker2
 
279
  combined_audio.export(output_filename, format="wav")
280
  return output_filename
281
 
282
+ async def generate_podcast(self, input_text: str, language: str, speaker1: str, speaker2: str, api_key: str) -> str:
283
+ gr.Info("Generating podcast script...")
284
+ podcast_json = await self.generate_script(input_text, language, api_key)
285
  print(f"Generated podcast script:\n{podcast_json}")
286
+ gr.Info("Generating podcast audio files...")
287
  audio_files = await asyncio.gather(*[self.tts_generate(item['line'], item['speaker'], speaker1, speaker2) for item in podcast_json['podcast']])
288
  combined_audio = await self.combine_audio_files(audio_files)
289
  return combined_audio
 
309
  elif file_extension.lower() == '.txt':
310
  return await cls.extract_from_txt(file_path)
311
  else:
312
+ gr.Error(f"Unsupported file type: {file_extension}")
313
  raise ValueError(f"Unsupported file type: {file_extension}")
314
 
315
+ async def process_input(input_text: str, input_file, language: str, speaker1: str, speaker2: str, api_key: str = "") -> str:
316
  voice_names = {
317
  "Andrew - English (United States)": "en-US-AndrewMultilingualNeural",
318
  "Ava - English (United States)": "en-US-AvaMultilingualNeural",
 
330
  if input_file:
331
  input_text = await TextExtractor.extract_text(input_file.name)
332
 
333
+ if not api_key:
334
+ api_key = os.getenv("GENAI_API_KEY")
335
+
336
+ podcast_generator = PodcastGenerator()
337
+ return await podcast_generator.generate_podcast(input_text, language, speaker1, speaker2, api_key)
338
 
339
  # Define Gradio interface
340
  iface = gr.Interface(
 
380
  "Remy - French (France)",
381
  "Vivienne - French (France)"
382
  ],
383
+ value="Ava - English (United States)"),
384
+ gr.Textbox(label="Gemini API Key (Optional)", type="password"),
385
  ],
386
  outputs=[
387
  gr.Audio(label="Generated Podcast Audio")