Bils committed · verified
Commit eaef5b0 · 1 Parent(s): a6afe59

Update app.py

Files changed (1):
  1. app.py +93 -32
app.py CHANGED
@@ -1,4 +1,4 @@
-# import os
+import os
 import re
 import torch
 import tempfile
@@ -39,7 +39,6 @@ def clean_text(text: str) -> str:
     """
     Removes undesired characters (e.g., asterisks) that might not be recognized by the model's vocabulary.
     """
-    # Remove all asterisks. You can add more cleaning steps here as needed.
     return re.sub(r'\*', '', text)
 
 # ---------------------------------------------------------------------
@@ -64,7 +63,6 @@ def get_llama_pipeline(model_id: str, token: str):
     LLAMA_PIPELINES[model_id] = text_pipeline
     return text_pipeline
 
-
 def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
     """
     Returns a cached MusicGen model if available; otherwise, loads it.
@@ -81,7 +79,6 @@ def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
     MUSICGEN_MODELS[model_key] = (model, processor)
     return model, processor
 
-
 def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
     """
     Returns a cached TTS model if available; otherwise, loads it.
@@ -93,7 +90,6 @@ def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
     TTS_MODELS[model_name] = tts_model
     return tts_model
 
-
 # ---------------------------------------------------------------------
 # Script Generation Function
 # ---------------------------------------------------------------------
@@ -105,7 +101,6 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
     """
     try:
         text_pipeline = get_llama_pipeline(model_id, token)
-
         system_prompt = (
             "You are an expert radio imaging producer specializing in sound design and music. "
             f"Based on the user's concept and the selected duration of {duration} seconds, produce the following: "
@@ -132,7 +127,7 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
         sound_design = "No sound design suggestions found."
         music_suggestions = "No music suggestions found."
 
-        # Voice-Over Script
+        # Extract Voice-Over Script
         if "Voice-Over Script:" in generated_text:
             parts = generated_text.split("Voice-Over Script:")
             voice_script_part = parts[1]
@@ -141,7 +136,7 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
         else:
             voice_script = voice_script_part.strip()
 
-        # Sound Design
+        # Extract Sound Design Suggestions
        if "Sound Design Suggestions:" in generated_text:
            parts = generated_text.split("Sound Design Suggestions:")
            sound_design_part = parts[1]
@@ -150,7 +145,7 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
         else:
             sound_design = sound_design_part.strip()
 
-        # Music Suggestions
+        # Extract Music Suggestions
         if "Music Suggestions:" in generated_text:
             parts = generated_text.split("Music Suggestions:")
             music_suggestions = parts[1].strip()
@@ -160,7 +155,6 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
     except Exception as e:
         return f"Error generating script: {e}", "", ""
 
-
 # ---------------------------------------------------------------------
 # Voice-Over Generation Function
 # ---------------------------------------------------------------------
@@ -174,12 +168,8 @@ def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/ta
         if not script.strip():
             return "Error: No script provided."
 
-        # Clean the script to remove special characters (e.g., asterisks) that may produce warnings
         cleaned_script = clean_text(script)
-
         tts_model = get_tts_model(tts_model_name)
-
-        # Generate and save voice
         output_path = os.path.join(tempfile.gettempdir(), "voice_over.wav")
         tts_model.tts_to_file(text=cleaned_script, file_path=output_path)
         return output_path
@@ -187,7 +177,6 @@ def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/ta
     except Exception as e:
         return f"Error generating voice: {e}"
 
-
 # ---------------------------------------------------------------------
 # Music Generation Function
 # ---------------------------------------------------------------------
@@ -205,23 +194,22 @@ def generate_music(prompt: str, audio_length: int):
         musicgen_model, musicgen_processor = get_musicgen_model(model_key)
 
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
+        # Process the input and move each tensor to the proper device
+        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
+        inputs = {k: v.to(device) for k, v in inputs.items()}
 
         with torch.inference_mode():
             outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
 
         audio_data = outputs[0, 0].cpu().numpy()
         normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
-
         output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")
         write(output_path, 44100, normalized_audio)
-
         return output_path
 
     except Exception as e:
         return f"Error generating music: {e}"
 
-
 # ---------------------------------------------------------------------
 # Audio Blending with Duration Sync & Ducking
 # ---------------------------------------------------------------------
@@ -241,17 +229,15 @@ def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int
         voice = AudioSegment.from_wav(voice_path)
         music = AudioSegment.from_wav(music_path)
 
-        voice_len = len(voice)  # in milliseconds
-        music_len = len(music)  # in milliseconds
+        voice_len = len(voice)
+        music_len = len(music)
 
-        # Loop music if it's shorter than the voice
         if music_len < voice_len:
             looped_music = AudioSegment.empty()
             while len(looped_music) < voice_len:
                 looped_music += music
             music = looped_music
 
-        # Trim music if it's longer than the voice
         if len(music) > voice_len:
             music = music[:voice_len]
 
@@ -268,6 +254,28 @@ def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int
     except Exception as e:
         return f"Error blending audio: {e}"
 
+# ---------------------------------------------------------------------
+# Agent Function: Orchestrate the Full Workflow
+# ---------------------------------------------------------------------
+@spaces.GPU(duration=400)
+def run_agent(user_prompt: str, llama_model_id: str, duration: int, tts_model_name: str, music_length: int, ducking: bool, duck_level: int):
+    """
+    Runs the full workflow as an agent:
+      1. Generates a script (voice-over, sound design, music suggestions) from a user prompt.
+      2. Synthesizes a voice-over from the generated script.
+      3. Generates a music track based on the music suggestions.
+      4. Blends the voice and music tracks.
+    Returns a tuple with the generated script components, voice file, music file, and final blended audio.
+    """
+    # Step 1: Generate Script
+    voice_script, sound_design, music_suggestions = generate_script(user_prompt, llama_model_id, HF_TOKEN, duration)
+    # Step 2: Generate Voice-Over
+    voice_file = generate_voice(voice_script, tts_model_name)
+    # Step 3: Generate Music
+    music_file = generate_music(music_suggestions, music_length)
+    # Step 4: Blend Audio
+    blended_file = blend_audio(voice_file, music_file, ducking, duck_level)
+    return voice_script, sound_design, music_suggestions, voice_file, music_file, blended_file
 
 # ---------------------------------------------------------------------
 # Gradio Interface with Enhanced UI
@@ -314,12 +322,12 @@ with gr.Blocks(css="""
     # Custom Header
     with gr.Row(elem_classes="header"):
         gr.Markdown("""
-            <h1>🎧 AI Ads Promo</h1>
-            <p>Your all-in-one AI solution for crafting engaging audio ads. <br><em>Demo MVP</em></p>
+            <h1>🎧 AI Promo Studio</h1>
+            <p>Your all-in-one AI solution for crafting engaging audio promos.</p>
         """)
 
     gr.Markdown("""
-    Welcome to **AI Ads Promo (Demo MVP)**! This platform leverages state-of-the-art AI models to help you generate:
+    Welcome to **AI Promo Studio**! This platform leverages state-of-the-art AI models to help you generate:
 
     - **Script**: Generate a compelling voice-over script with LLaMA.
    - **Voice Synthesis**: Create natural-sounding voice-overs using Coqui TTS.
@@ -328,7 +336,7 @@ with gr.Blocks(css="""
     """)
 
     with gr.Tabs():
-        # Step 1: Generate Script
+        # Tab 1: Script Generation
         with gr.Tab("📝 Script Generation"):
             with gr.Row():
                 user_prompt = gr.Textbox(
@@ -360,7 +368,7 @@ with gr.Blocks(css="""
                 outputs=[script_output, sound_design_output, music_suggestion_output],
             )
 
-        # Step 2: Generate Voice
+        # Tab 2: Voice Synthesis
         with gr.Tab("🎤 Voice Synthesis"):
             gr.Markdown("Generate a natural-sounding voice-over using Coqui TTS.")
             selected_tts_model = gr.Dropdown(
@@ -382,7 +390,7 @@ with gr.Blocks(css="""
                 outputs=voice_audio_output,
             )
 
-        # Step 3: Generate Music
+        # Tab 3: Music Production
         with gr.Tab("🎶 Music Production"):
             gr.Markdown("Generate a custom music track using the **MusicGen Large** model.")
             audio_length = gr.Slider(
@@ -402,7 +410,7 @@ with gr.Blocks(css="""
                 outputs=[music_output],
            )
 
-        # Step 4: Blend Audio
+        # Tab 4: Audio Blending
         with gr.Tab("🎚️ Audio Blending"):
             gr.Markdown("Blend your voice-over and music track. Music will be looped/truncated to match the voice duration. Enable ducking to lower the music during voice segments.")
             ducking_checkbox = gr.Checkbox(label="Enable Ducking?", value=True)
@@ -422,17 +430,70 @@ with gr.Blocks(css="""
                 outputs=blended_output
             )
 
+        # Tab 5: Agent – Full Workflow
+        with gr.Tab("🤖 Agent"):
+            gr.Markdown("Let the agent handle everything in one go: generate the script, synthesize voice, produce music, and blend the final ad.")
+            with gr.Row():
+                agent_prompt = gr.Textbox(
+                    label="Ad Promo Idea",
+                    placeholder="Enter your ad promo concept...",
+                    lines=2
+                )
+            with gr.Row():
+                agent_llama_model_id = gr.Textbox(
+                    label="LLaMA Model ID",
+                    value="meta-llama/Meta-Llama-3-8B-Instruct",
+                    placeholder="Enter a valid Hugging Face model ID"
+                )
+                agent_duration = gr.Slider(
+                    label="Promo Duration (seconds)",
+                    minimum=15, maximum=60, step=15, value=30
+                )
+            with gr.Row():
+                agent_tts_model = gr.Dropdown(
+                    label="TTS Model",
+                    choices=[
+                        "tts_models/en/ljspeech/tacotron2-DDC",
+                        "tts_models/en/ljspeech/vits",
+                        "tts_models/en/sam/tacotron-DDC",
+                    ],
+                    value="tts_models/en/ljspeech/tacotron2-DDC",
+                    multiselect=False
+                )
+                agent_music_length = gr.Slider(
+                    label="Music Length (tokens)",
+                    minimum=128, maximum=1024, step=64, value=512
+                )
+            with gr.Row():
+                agent_ducking = gr.Checkbox(label="Enable Ducking?", value=True)
+                agent_duck_level = gr.Slider(
+                    label="Ducking Level (dB attenuation)",
+                    minimum=0, maximum=20, step=1, value=10
+                )
+            agent_run_button = gr.Button("Run Agent", variant="primary")
+            agent_script_output = gr.Textbox(label="Generated Voice-Over Script", lines=5, interactive=False)
+            agent_sound_output = gr.Textbox(label="Sound Design Suggestions", lines=3, interactive=False)
+            agent_music_suggestions_output = gr.Textbox(label="Music Suggestions", lines=3, interactive=False)
+            agent_voice_audio = gr.Audio(label="Voice-Over (WAV)", type="filepath")
+            agent_music_audio = gr.Audio(label="Generated Music (WAV)", type="filepath")
+            agent_blended_audio = gr.Audio(label="Final Blended Output (WAV)", type="filepath")
+
+            agent_run_button.click(
+                fn=run_agent,
+                inputs=[agent_prompt, agent_llama_model_id, agent_duration, agent_tts_model, agent_music_length, agent_ducking, agent_duck_level],
+                outputs=[agent_script_output, agent_sound_output, agent_music_suggestions_output, agent_voice_audio, agent_music_audio, agent_blended_audio]
+            )
+
     # Footer
     gr.Markdown("""
     <div class="footer">
         <hr>
         Created with ❤️ by <a href="https://bilsimaging.com" target="_blank" style="color: #88aaff;">bilsimaging.com</a>
         <br>
-        <small>AI Ads Promo (Demo MVP) &copy; 2025</small>
+        <small>AI Promo Studio &copy; 2025</small>
     </div>
     """)
 
-    # Visitor Badge
     gr.HTML("""
     <div style="text-align: center; margin-top: 1rem;">
         <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
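A minimal sketch for exercising the new run_agent() orchestration outside the Gradio UI. This is only an illustration, assuming app.py can be imported without launching the demo, HF_TOKEN is configured as the Space expects, and either a GPU is available or the @spaces.GPU decorator passes through when run off-Space; the prompt and parameter values are examples, not part of the commit.

    # Hypothetical smoke test for run_agent() (illustrative only, not part of this commit).
    # Assumes app.py is importable and HF_TOKEN / model access are set up as in the Space.
    from app import run_agent

    script, sound_design, music_ideas, voice_wav, music_wav, final_wav = run_agent(
        user_prompt="30-second promo for a late-night jazz radio show",  # example concept
        llama_model_id="meta-llama/Meta-Llama-3-8B-Instruct",
        duration=30,
        tts_model_name="tts_models/en/ljspeech/tacotron2-DDC",
        music_length=512,  # MusicGen max_new_tokens, matching the UI slider default
        ducking=True,
        duck_level=10,     # dB of attenuation applied to the music under the voice
    )
    print("Blended promo written to:", final_wav)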