Bils committed on
Commit a8a7982 · verified · 1 Parent(s): f2c044d

Update app.py

Files changed (1):
  app.py (+279 -369)
app.py CHANGED
@@ -1,8 +1,6 @@
 import os
-import uuid
 import torch
-import numpy as np
-import gradio as gr
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
@@ -15,23 +13,30 @@ from pydub import AudioSegment
 from dotenv import load_dotenv
 import tempfile
 import spaces
 from TTS.api import TTS
 
-# -----------------------------------------------------------
-# Initialization & Environment Setup
-# -----------------------------------------------------------
 load_dotenv()
 HF_TOKEN = os.getenv("HF_TOKEN")
 
-# -----------------------------------------------------------
-# Model Cache Management
-# -----------------------------------------------------------
 LLAMA_PIPELINES = {}
 MUSICGEN_MODELS = {}
 TTS_MODELS = {}
 
 def get_llama_pipeline(model_id: str, token: str):
-    """Load and cache the LLaMA text-generation pipeline."""
     if model_id in LLAMA_PIPELINES:
         return LLAMA_PIPELINES[model_id]
 
@@ -47,434 +52,339 @@ def get_llama_pipeline(model_id: str, token: str):
     LLAMA_PIPELINES[model_id] = text_pipeline
     return text_pipeline
 
 def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
-    """Load and cache the MusicGen model and processor."""
     if model_key in MUSICGEN_MODELS:
         return MUSICGEN_MODELS[model_key]
 
     model = MusicgenForConditionalGeneration.from_pretrained(model_key)
     processor = AutoProcessor.from_pretrained(model_key)
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model.to(device)
     MUSICGEN_MODELS[model_key] = (model, processor)
     return model, processor
 
 def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
-    """Load and cache the TTS model."""
     if model_name in TTS_MODELS:
         return TTS_MODELS[model_name]
     tts_model = TTS(model_name)
     TTS_MODELS[model_name] = tts_model
     return tts_model
 
-# -----------------------------------------------------------
-# Core Functionality
-# -----------------------------------------------------------
 @spaces.GPU(duration=100)
 def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
     """
-    Generate a professional promo script including a voice-over script,
-    sound design suggestions, and music recommendations.
     """
     try:
         text_pipeline = get_llama_pipeline(model_id, token)
-        # Updated prompt to instruct the model to output sections with explicit headers.
         system_prompt = (
-            f"You are a professional audio producer creating {duration}-second content. "
-            "Please generate the following three sections exactly as shown:\n\n"
-            "Voice-Over Script: [A clear and concise script for the voiceover.]\n"
-            "Sound Design Suggestions: [Specific ideas, effects, and ambience recommendations.]\n"
-            "Music Suggestions: [Recommendations for music style, genre, and tempo.]\n\n"
-            "Make sure each section starts with its header exactly."
         )
-
-        full_prompt = f"{system_prompt}\nClient brief: {user_prompt}\nOutput:"
-
         with torch.inference_mode():
             result = text_pipeline(
-                full_prompt,
-                max_new_tokens=400,
                 do_sample=True,
-                temperature=0.7,
-                top_p=0.9
             )
 
-        generated_text = result[0]["generated_text"].split("Output:")[-1].strip()
-
-        # Parse the output into the three expected sections.
-        sections = {
-            "Voice-Over Script:": "",
-            "Sound Design Suggestions:": "",
-            "Music Suggestions:": ""
-        }
-
-        current_section = None
-        for line in generated_text.split('\n'):
-            for section in sections:
-                if section in line:
-                    current_section = section
-                    # Remove header from the line.
-                    line = line.replace(section, '').strip()
-                    break
-            if current_section:
-                sections[current_section] += line + '\n'
-
-        return (
-            sections["Voice-Over Script:"].strip() or "No script generated",
-            sections["Sound Design Suggestions:"].strip() or "No sound design suggestions",
-            sections["Music Suggestions:"].strip() or "No music suggestions"
-        )
 
     except Exception as e:
-        return f"Error: {str(e)}", "", ""
 @spaces.GPU(duration=100)
-def generate_voice(script: str, tts_model_name: str):
     """
-    Generate full voice-over audio from the provided script using a TTS model.
     """
     try:
         if not script.strip():
-            return None
         tts_model = get_tts_model(tts_model_name)
-        # Create a unique temporary file name for the output.
-        output_path = os.path.join(tempfile.gettempdir(), f"voice_{uuid.uuid4().hex}.wav")
         tts_model.tts_to_file(text=script, file_path=output_path)
         return output_path
-    except Exception as e:
-        print(f"Voice generation error: {e}")
-        return None
 
-@spaces.GPU(duration=100)
-def generate_voice_preview(script: str, tts_model_name: str):
-    """
-    Generate a short preview of the voice-over by taking the first 100 words.
-    """
-    try:
-        if not script.strip():
-            return None
-        words = script.split()
-        preview_text = ' '.join(words[:100]) if len(words) > 100 else script
-        return generate_voice(preview_text, tts_model_name)
     except Exception as e:
-        print(f"Voice preview error: {e}")
-        return None
 @spaces.GPU(duration=100)
 def generate_music(prompt: str, audio_length: int):
     """
-    Generate music audio from a text prompt using the MusicGen model.
     """
     try:
-        model, processor = get_musicgen_model()
         device = "cuda" if torch.cuda.is_available() else "cpu"
-        inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(device)
-
         with torch.inference_mode():
-            outputs = model.generate(**inputs, max_new_tokens=audio_length)
-
-        # Assuming outputs[0, 0] holds the generated audio waveform.
         audio_data = outputs[0, 0].cpu().numpy()
-        # Prevent division by zero during normalization.
-        max_val = np.max(np.abs(audio_data))
-        if max_val == 0:
-            normalized_audio = audio_data.astype("int16")
-        else:
-            normalized_audio = (audio_data / max_val * 32767).astype("int16")
-        output_path = os.path.join(tempfile.gettempdir(), f"music_{uuid.uuid4().hex}.wav")
         write(output_path, 44100, normalized_audio)
         return output_path
     except Exception as e:
-        print(f"Music generation error: {e}")
-        return None
 @spaces.GPU(duration=100)
-def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int):
     """
-    Blend the generated voice and music audio files.
-    If ducking is enabled, lower the music volume during the voice segments.
     """
     try:
         voice = AudioSegment.from_wav(voice_path)
         music = AudioSegment.from_wav(music_path)
-
-        # Loop the music track if it's shorter than the voice track.
-        if len(music) < len(voice):
-            loops_needed = (len(voice) // len(music)) + 1
-            music = music * loops_needed
-        music = music[:len(voice)]
-
         if ducking:
             ducked_music = music - duck_level
             final_audio = ducked_music.overlay(voice)
         else:
             final_audio = music.overlay(voice)
-
-        output_path = os.path.join(tempfile.gettempdir(), f"final_mix_{uuid.uuid4().hex}.wav")
         final_audio.export(output_path, format="wav")
         return output_path
     except Exception as e:
-        print(f"Mixing error: {e}")
-        return None
-
-# -----------------------------------------------------------
-# Enhanced UI Components
-# -----------------------------------------------------------
- custom_css = """
224
- #main-container {
225
- max-width: 1200px;
226
- margin: 0 auto;
227
- padding: 20px;
228
- background: #f0f9fb;
229
- border-radius: 15px;
230
- box-shadow: 0 4px 6px rgba(0,0,0,0.05);
231
- }
232
-
233
- .header {
234
- text-align: center;
235
- padding: 2em;
236
- background: linear-gradient(135deg, #2a9d8f 0%, #457b9d 100%);
237
- color: white;
238
- border-radius: 15px;
239
- margin-bottom: 2em;
240
- border: 1px solid #264653;
241
- }
242
-
243
- .tab-nav {
244
- background: none !important;
245
- border: none !important;
246
- }
247
-
248
- .tab-button {
249
- padding: 1em 2em !important;
250
- border-radius: 8px !important;
251
- margin: 0 5px !important;
252
- transition: all 0.3s ease !important;
253
- background: #e9f5f4 !important;
254
- border: 1px solid #a8dadc !important;
255
- color: #1d3557 !important;
256
- }
257
-
258
- .tab-button:hover {
259
- transform: translateY(-2px);
260
- box-shadow: 0 3px 6px rgba(42,157,143,0.2);
261
- background: #caf0f8 !important;
262
- }
263
-
264
- .dark-btn {
265
- background: linear-gradient(135deg, #457b9d 0%, #2a9d8f 100%) !important;
266
- color: white !important;
267
- border: none !important;
268
- padding: 12px 24px !important;
269
- border-radius: 8px !important;
270
- transition: transform 0.2s ease !important;
271
- }
272
-
273
- .dark-btn:hover {
274
- transform: scale(1.02);
275
- box-shadow: 0 3px 8px rgba(42,157,143,0.3);
276
- }
277
-
278
- .output-card {
279
- background: #f8fbfe !important;
280
- border-radius: 10px !important;
281
- padding: 20px !important;
282
- box-shadow: 0 2px 4px rgba(69,123,157,0.1) !important;
283
- border: 1px solid #e2e8f0;
284
- }
285
-
286
- .progress-indicator {
287
- color: #457b9d;
288
- font-style: italic;
289
- margin-top: 10px;
290
- }
291
-
292
- /* Additional Color Elements */
293
- h1, h2, h3 {
294
- color: #1d3557 !important;
295
- }
296
-
297
- audio {
298
- border: 1px solid #a8dadc !important;
299
- border-radius: 8px !important;
300
- }
301
-
302
- .slider-handle {
303
- background: #2a9d8f !important;
304
- }
305
- """
306
-
-with gr.Blocks(css=custom_css, theme=gr.themes.Default()) as demo:
-    with gr.Column(elem_id="main-container"):
-        # Header Section
-        with gr.Column(elem_classes="header"):
-            gr.Markdown("""
-            # 🎙️ AI Promo Studio
-            **Professional Audio Production Suite Powered by AI**
-            """)
-
-        # Main Workflow Tabs
-        with gr.Tabs(elem_classes="tab-nav"):
-            # Script Generation Tab
-            with gr.Tab("📝 Script Design", elem_classes="tab-button"):
-                with gr.Row(equal_height=False):
-                    with gr.Column(scale=2):
-                        gr.Markdown("### 🎯 Project Brief")
-                        user_prompt = gr.Textbox(
-                            label="Describe your promo concept",
-                            placeholder="e.g., 'An intense 30-second movie trailer intro with epic orchestral music and dramatic sound effects...'",
-                            lines=4
-                        )
-                        with gr.Row():
-                            duration = gr.Slider(
-                                label="Duration (seconds)",
-                                minimum=15,
-                                maximum=120,
-                                step=15,
-                                value=30,
-                                interactive=True
-                            )
-                            llama_model_id = gr.Dropdown(
-                                label="AI Model",
-                                choices=["meta-llama/Meta-Llama-3-8B-Instruct"],
-                                value="meta-llama/Meta-Llama-3-8B-Instruct",
-                                interactive=True
-                            )
-                        generate_btn = gr.Button("Generate Script 🚀", elem_classes="dark-btn")
-
-                    with gr.Column(scale=1, elem_classes="output-card"):
-                        gr.Markdown("### 📄 Generated Content")
-                        script_output = gr.Textbox(label="Voice Script", lines=6)
-                        sound_design_output = gr.Textbox(label="Sound Design", lines=3)
-                        music_suggestion_output = gr.Textbox(label="Music Style", lines=3)
-
-            # Voice Production Tab
-            with gr.Tab("🎙️ Voice Production", elem_classes="tab-button"):
-                with gr.Row():
-                    with gr.Column(scale=1):
-                        gr.Markdown("### 🔊 Voice Settings")
-                        tts_model = gr.Dropdown(
-                            label="Voice Model",
-                            choices=[
-                                "tts_models/en/ljspeech/tacotron2-DDC",
-                                "tts_models/en/ljspeech/vits",
-                                "tts_models/en/sam/tacotron-DDC"
-                            ],
-                            value="tts_models/en/ljspeech/tacotron2-DDC",
-                            interactive=True
-                        )
-                        with gr.Row():
-                            voice_preview_btn = gr.Button("Preview Sample", elem_classes="dark-btn")
-                            voice_generate_btn = gr.Button("Generate Full Voiceover", elem_classes="dark-btn")
-                    with gr.Column(scale=1, elem_classes="output-card"):
-                        gr.Markdown("### 🎧 Voice Preview")
-                        voice_audio = gr.Audio(
-                            label="Generated Voice",
-                            interactive=False,
-                            waveform_options={"show_controls": True}
-                        )
-
-            # Music Production Tab
-            with gr.Tab("🎵 Music Design", elem_classes="tab-button"):
-                with gr.Row():
-                    with gr.Column(scale=1):
-                        gr.Markdown("### 🎹 Music Parameters")
-                        audio_length = gr.Slider(
-                            label="Generation Length",
-                            minimum=256,
-                            maximum=1024,
-                            step=64,
-                            value=512,
-                            info="Higher values = longer generation time"
-                        )
-                        music_generate_btn = gr.Button("Generate Music Track", elem_classes="dark-btn")
-                    with gr.Column(scale=1, elem_classes="output-card"):
-                        gr.Markdown("### 🎶 Music Preview")
-                        music_output = gr.Audio(
-                            label="Generated Music",
-                            interactive=False,
-                            waveform_options={"show_controls": True}
-                        )
-
-            # Final Mix Tab
-            with gr.Tab("🔊 Final Mix", elem_classes="tab-button"):
-                with gr.Row():
-                    with gr.Column(scale=1):
-                        gr.Markdown("### 🎚️ Mixing Console")
-                        ducking_enabled = gr.Checkbox(
-                            label="Enable Voice Ducking",
-                            value=True,
-                            info="Automatically lower music during voice segments"
-                        )
-                        duck_level = gr.Slider(
-                            label="Ducking Intensity (dB)",
-                            minimum=3,
-                            maximum=20,
-                            step=1,
-                            value=10
-                        )
-                        mix_btn = gr.Button("Generate Final Mix", elem_classes="dark-btn")
-                    with gr.Column(scale=1, elem_classes="output-card"):
-                        gr.Markdown("### 🎧 Final Production")
-                        final_mix = gr.Audio(
-                            label="Mixed Output",
-                            interactive=False,
-                            waveform_options={"show_controls": True}
-                        )
-
-        # Footer Section
-        with gr.Column(elem_classes="output-card"):
-            gr.Markdown("""
-            <div style="text-align: center; padding: 1.5em 0;">
-                <a href="https://bilsimaging.com" target="_blank">
-                    <img src="https://bilsimaging.com/logo.png" alt="Bils Imaging" style="height: 35px; margin-right: 15px;">
-                </a>
-                <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
-                    <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" />
-                </a>
-            </div>
-            <p style="text-align: center; color: #666; font-size: 0.9em;">
-                Professional Audio Production Suite v2.1 © 2024 | Bils Imaging
-            </p>
-            """)
-
-    # -----------------------------------------------------------
-    # Event Handling
-    # -----------------------------------------------------------
-    # Hidden textbox for HF_TOKEN (its value is set via the environment variable).
-    hf_token_hidden = gr.Textbox(value=HF_TOKEN, visible=False)
-
-    generate_btn.click(
-        generate_script,
-        inputs=[user_prompt, llama_model_id, hf_token_hidden, duration],
-        outputs=[script_output, sound_design_output, music_suggestion_output]
-    )
-
-    # Voice preview: generates a trimmed version of the script.
-    voice_preview_btn.click(
-        generate_voice_preview,
-        inputs=[script_output, tts_model],
-        outputs=voice_audio
-    )
-
-    # Full voice generation using the complete script.
-    voice_generate_btn.click(
-        generate_voice,
-        inputs=[script_output, tts_model],
-        outputs=voice_audio
-    )
-
-    music_generate_btn.click(
-        generate_music,
-        inputs=[music_suggestion_output, audio_length],
-        outputs=music_output
-    )
-
-    mix_btn.click(
-        blend_audio,
-        inputs=[voice_audio, music_output, ducking_enabled, duck_level],
-        outputs=final_mix
-    )
 
-if __name__ == "__main__":
-    demo.launch(debug=True)
+import gradio as gr
 import os
 import torch
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
 from dotenv import load_dotenv
 import tempfile
 import spaces
+
+# Coqui TTS
 from TTS.api import TTS
 
+# ---------------------------------------------------------------------
+# Load Environment Variables
+# ---------------------------------------------------------------------
 load_dotenv()
 HF_TOKEN = os.getenv("HF_TOKEN")
 
+# ---------------------------------------------------------------------
+# Global Model Caches
+# ---------------------------------------------------------------------
 LLAMA_PIPELINES = {}
 MUSICGEN_MODELS = {}
 TTS_MODELS = {}
 
+# ---------------------------------------------------------------------
+# Helper Functions
+# ---------------------------------------------------------------------
 def get_llama_pipeline(model_id: str, token: str):
+    """
+    Returns a cached LLaMA pipeline if available; otherwise, loads it.
+    """
     if model_id in LLAMA_PIPELINES:
         return LLAMA_PIPELINES[model_id]
 
     LLAMA_PIPELINES[model_id] = text_pipeline
     return text_pipeline
 
+
 def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
+    """
+    Returns a cached MusicGen model if available; otherwise, loads it.
+    Uses the 'large' variant for higher quality outputs.
+    """
     if model_key in MUSICGEN_MODELS:
         return MUSICGEN_MODELS[model_key]
 
     model = MusicgenForConditionalGeneration.from_pretrained(model_key)
     processor = AutoProcessor.from_pretrained(model_key)
+
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model.to(device)
     MUSICGEN_MODELS[model_key] = (model, processor)
     return model, processor
 
+
 def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
+    """
+    Returns a cached TTS model if available; otherwise, loads it.
+    """
    if model_name in TTS_MODELS:
        return TTS_MODELS[model_name]
+
    tts_model = TTS(model_name)
    TTS_MODELS[model_name] = tts_model
    return tts_model
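Note on the helper trio above: all three share one memoization pattern, a module-level dict keyed by model name, so each model loads at most once per process and is reused across Gradio callbacks. A minimal self-contained sketch of that pattern (the `load_model` stub is a hypothetical stand-in for the real `from_pretrained` / `TTS(...)` loaders):

```python
# Minimal sketch of the module-level cache pattern used above.
# `load_model` is a hypothetical stand-in for the real, expensive loaders.
_CACHE = {}

def load_model(name: str) -> object:
    return f"<model {name}>"  # pretend this is slow and memory-hungry

def get_model(name: str):
    if name in _CACHE:           # hit: reuse the already-loaded model
        return _CACHE[name]
    model = load_model(name)     # miss: load once...
    _CACHE[name] = model         # ...then remember it
    return model

assert get_model("a") is get_model("a")  # same object on repeat calls
```

Because the dicts live at module scope, the cache persists for the life of the process but is not shared across replicas.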
 
+
+# ---------------------------------------------------------------------
+# Script Generation Function
+# ---------------------------------------------------------------------
 @spaces.GPU(duration=100)
 def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
     """
+    Generates a script, sound design suggestions, and music ideas from a user prompt.
+    Returns a tuple of strings: (voice_script, sound_design, music_suggestions).
     """
     try:
         text_pipeline = get_llama_pipeline(model_id, token)
+
         system_prompt = (
+            "You are an expert radio imaging producer specializing in sound design and music. "
+            f"Based on the user's concept and the selected duration of {duration} seconds, produce the following: "
+            "1. A concise voice-over script. Prefix this section with 'Voice-Over Script:'.\n"
+            "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'.\n"
+            "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'."
         )
+        combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"
+
         with torch.inference_mode():
             result = text_pipeline(
+                combined_prompt,
+                max_new_tokens=300,
                 do_sample=True,
+                temperature=0.8
             )
 
+        generated_text = result[0]["generated_text"]
+        if "Output:" in generated_text:
+            generated_text = generated_text.split("Output:")[-1].strip()
+
+        # Default placeholders
+        voice_script = "No voice-over script found."
+        sound_design = "No sound design suggestions found."
+        music_suggestions = "No music suggestions found."
+
+        # Voice-Over Script
+        if "Voice-Over Script:" in generated_text:
+            parts = generated_text.split("Voice-Over Script:")
+            voice_script_part = parts[1]
+            if "Sound Design Suggestions:" in voice_script_part:
+                voice_script = voice_script_part.split("Sound Design Suggestions:")[0].strip()
+            else:
+                voice_script = voice_script_part.strip()
+
+        # Sound Design
+        if "Sound Design Suggestions:" in generated_text:
+            parts = generated_text.split("Sound Design Suggestions:")
+            sound_design_part = parts[1]
+            if "Music Suggestions:" in sound_design_part:
+                sound_design = sound_design_part.split("Music Suggestions:")[0].strip()
+            else:
+                sound_design = sound_design_part.strip()
+
+        # Music Suggestions
+        if "Music Suggestions:" in generated_text:
+            parts = generated_text.split("Music Suggestions:")
+            music_suggestions = parts[1].strip()
+
+        return voice_script, sound_design, music_suggestions
 
     except Exception as e:
+        return f"Error generating script: {e}", "", ""
 
+
+# ---------------------------------------------------------------------
+# Voice-Over Generation Function
+# ---------------------------------------------------------------------
 @spaces.GPU(duration=100)
+def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
     """
+    Generates a voice-over from the provided script using the Coqui TTS model.
+    Returns the file path to the generated .wav file.
     """
     try:
         if not script.strip():
+            return "Error: No script provided."
+
         tts_model = get_tts_model(tts_model_name)
+
+        # Generate and save voice
+        output_path = os.path.join(tempfile.gettempdir(), "voice_over.wav")
         tts_model.tts_to_file(text=script, file_path=output_path)
         return output_path
     except Exception as e:
+        return f"Error generating voice: {e}"
 
 
+
+# ---------------------------------------------------------------------
+# Music Generation Function
+# ---------------------------------------------------------------------
 @spaces.GPU(duration=100)
 def generate_music(prompt: str, audio_length: int):
     """
+    Generates music from the 'facebook/musicgen-large' model based on the prompt.
+    Returns the file path to the generated .wav file.
     """
     try:
+        if not prompt.strip():
+            return "Error: No music suggestion provided."
+
+        model_key = "facebook/musicgen-large"
+        musicgen_model, musicgen_processor = get_musicgen_model(model_key)
+
         device = "cuda" if torch.cuda.is_available() else "cpu"
+        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
+
         with torch.inference_mode():
+            outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
+
         audio_data = outputs[0, 0].cpu().numpy()
+        normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
+
+        output_path = f"{tempfile.gettempdir()}/musicgen_large_generated_music.wav"
         write(output_path, 44100, normalized_audio)
+
         return output_path
+
     except Exception as e:
+        return f"Error generating music: {e}"
+
 
+# ---------------------------------------------------------------------
+# Audio Blending with Duration Sync & Ducking
+# ---------------------------------------------------------------------
 @spaces.GPU(duration=100)
+def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int = 10):
     """
+    Blends two audio files (voice and music).
+    1. If music < voice, loops the music until it meets/exceeds the voice duration.
+    2. If music > voice, trims music to the voice duration.
+    3. If ducking=True, the music is attenuated by 'duck_level' dB while the voice is playing.
+    Returns the file path to the blended .wav file.
     """
     try:
+        if not os.path.isfile(voice_path) or not os.path.isfile(music_path):
+            return "Error: Missing audio files for blending."
+
         voice = AudioSegment.from_wav(voice_path)
         music = AudioSegment.from_wav(music_path)
+
+        voice_len = len(voice)  # in milliseconds
+        music_len = len(music)  # in milliseconds
+
+        # 1) If the music is shorter than the voice, loop it:
+        if music_len < voice_len:
+            looped_music = AudioSegment.empty()
+            # Keep appending until we exceed voice length
+            while len(looped_music) < voice_len:
+                looped_music += music
+            music = looped_music
+
+        # 2) If the music is longer than the voice, truncate it:
+        if len(music) > voice_len:
+            music = music[:voice_len]
+
+        # Now music and voice are the same length
         if ducking:
+            # Step 1: Reduce music dB while voice is playing
             ducked_music = music - duck_level
+            # Step 2: Overlay voice on top of ducked music
             final_audio = ducked_music.overlay(voice)
         else:
+            # No ducking, just overlay
             final_audio = music.overlay(voice)
+
+        output_path = os.path.join(tempfile.gettempdir(), "blended_output.wav")
         final_audio.export(output_path, format="wav")
         return output_path
+
     except Exception as e:
+        return f"Error blending audio: {e}"
+
+
+# ---------------------------------------------------------------------
+# Gradio Interface
+# ---------------------------------------------------------------------
+with gr.Blocks() as demo:
+    gr.Markdown("""
+    # 🎧 AI Promo Studio
+    Welcome to **AI Promo Studio**, your all-in-one solution for creating professional, engaging audio promos with minimal effort!
+
+    This next-generation platform uses powerful AI models to handle:
+    - **Script Generation**: Craft concise and impactful copy with LLaMA.
+    - **Voice Synthesis**: Convert text into natural-sounding voice-overs using Coqui TTS.
+    - **Music Production**: Generate custom music tracks with MusicGen Large for sound bed.
+    - **Seamless Blending**: Easily combine voice and music—loop or trim tracks to match your desired promo length, with optional ducking to keep the voice front and center.
+
+    Whether you’re a radio producer, podcaster, or content creator, **AI Promo Studio** streamlines your entire production pipeline—cutting hours of manual editing down to a few clicks.
+    """)
+
+    with gr.Tabs():
+        # Step 1: Generate Script
+        with gr.Tab("Step 1: Generate Script"):
+            with gr.Row():
+                user_prompt = gr.Textbox(
+                    label="Promo Idea",
+                    placeholder="E.g., A 30-second promo for a morning show...",
+                    lines=2
+                )
+                llama_model_id = gr.Textbox(
+                    label="LLaMA Model ID",
+                    value="meta-llama/Meta-Llama-3-8B-Instruct",
+                    placeholder="Enter a valid Hugging Face model ID"
+                )
+                duration = gr.Slider(
+                    label="Desired Promo Duration (seconds)",
+                    minimum=15,
+                    maximum=60,
+                    step=15,
+                    value=30
+                )
+
+            generate_script_button = gr.Button("Generate Script")
+            script_output = gr.Textbox(label="Generated Voice-Over Script", lines=5, interactive=False)
+            sound_design_output = gr.Textbox(label="Sound Design Suggestions", lines=3, interactive=False)
+            music_suggestion_output = gr.Textbox(label="Music Suggestions", lines=3, interactive=False)
+
+            generate_script_button.click(
+                fn=lambda user_prompt, model_id, dur: generate_script(user_prompt, model_id, HF_TOKEN, dur),
+                inputs=[user_prompt, llama_model_id, duration],
+                outputs=[script_output, sound_design_output, music_suggestion_output],
+            )
+
+        # Step 2: Generate Voice
+        with gr.Tab("Step 2: Generate Voice"):
+            gr.Markdown("Generate the voice-over using a Coqui TTS model.")
+            selected_tts_model = gr.Dropdown(
+                label="TTS Model",
+                choices=[
+                    "tts_models/en/ljspeech/tacotron2-DDC",
+                    "tts_models/en/ljspeech/vits",
+                    "tts_models/en/sam/tacotron-DDC",
+                ],
+                value="tts_models/en/ljspeech/tacotron2-DDC",
+                multiselect=False
+            )
+            generate_voice_button = gr.Button("Generate Voice-Over")
+            voice_audio_output = gr.Audio(label="Voice-Over (WAV)", type="filepath")
+
+            generate_voice_button.click(
+                fn=lambda script, tts_model: generate_voice(script, tts_model),
+                inputs=[script_output, selected_tts_model],
+                outputs=voice_audio_output,
+            )
+
+        # Step 3: Generate Music (MusicGen Large)
+        with gr.Tab("Step 3: Generate Music"):
+            gr.Markdown("Generate a music track with the **MusicGen Large** model.")
+            audio_length = gr.Slider(
+                label="Music Length (tokens)",
+                minimum=128,
+                maximum=1024,
+                step=64,
+                value=512,
+                info="Increase tokens for longer audio, but be mindful of inference time."
+            )
+            generate_music_button = gr.Button("Generate Music")
+            music_output = gr.Audio(label="Generated Music (WAV)", type="filepath")
+
+            generate_music_button.click(
+                fn=lambda music_suggestion, length: generate_music(music_suggestion, length),
+                inputs=[music_suggestion_output, audio_length],
+                outputs=[music_output],
+            )
+
+        # Step 4: Blend Audio (Loop/Trim + Ducking)
+        with gr.Tab("Step 4: Blend Audio"):
+            gr.Markdown("**Music** will be looped or trimmed to match **Voice** duration, then optionally ducked.")
+            ducking_checkbox = gr.Checkbox(label="Enable Ducking?", value=True)
+            duck_level_slider = gr.Slider(
+                label="Ducking Level (dB attenuation)",
+                minimum=0,
+                maximum=20,
+                step=1,
+                value=10
+            )
+            blend_button = gr.Button("Blend Voice + Music")
+            blended_output = gr.Audio(label="Final Blended Output (WAV)", type="filepath")
+
+            blend_button.click(
+                fn=blend_audio,
+                inputs=[voice_audio_output, music_output, ducking_checkbox, duck_level_slider],
+                outputs=blended_output
+            )
 
+    # Footer
+    gr.Markdown("""
+    <hr>
+    <p style="text-align: center; font-size: 0.9em;">
+        Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
+    </p>
+    """)
+
+    # Visitor Badge
+    gr.HTML("""
+    <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
+        <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" />
+    </a>
+    """)
+
+demo.launch(debug=True)