Bils commited on
Commit
559ca26
·
verified ·
1 Parent(s): 2925d53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +148 -70
app.py CHANGED
@@ -13,22 +13,22 @@ from pydub import AudioSegment
13
  from dotenv import load_dotenv
14
  import tempfile
15
  import spaces
 
 
16
  from TTS.api import TTS
17
- from TTS.utils.synthesizer import Synthesizer
18
 
19
  # ---------------------------------------------------------------------
20
  # Load Environment Variables
21
  # ---------------------------------------------------------------------
22
  load_dotenv()
23
- HF_TOKEN = os.getenv("HF_TOKEN")
24
 
25
  # ---------------------------------------------------------------------
26
  # Global Model Caches
27
  # ---------------------------------------------------------------------
28
- # We store models/pipelines in global variables for reuse,
29
- # so they are only loaded once.
30
  LLAMA_PIPELINES = {}
31
  MUSICGEN_MODELS = {}
 
32
 
33
  # ---------------------------------------------------------------------
34
  # Helper Functions
@@ -36,12 +36,10 @@ MUSICGEN_MODELS = {}
36
  def get_llama_pipeline(model_id: str, token: str):
37
  """
38
  Returns a cached LLaMA pipeline if available; otherwise, loads it.
39
- This significantly reduces loading time for repeated calls.
40
  """
41
  if model_id in LLAMA_PIPELINES:
42
  return LLAMA_PIPELINES[model_id]
43
 
44
- # Load new pipeline and store in cache
45
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
46
  model = AutoModelForCausalLM.from_pretrained(
47
  model_id,
@@ -55,14 +53,14 @@ def get_llama_pipeline(model_id: str, token: str):
55
  return text_pipeline
56
 
57
 
58
- def get_musicgen_model(model_key: str = "facebook/musicgen-medium"):
59
  """
60
  Returns a cached MusicGen model if available; otherwise, loads it.
 
61
  """
62
  if model_key in MUSICGEN_MODELS:
63
  return MUSICGEN_MODELS[model_key]
64
 
65
- # Load new MusicGen model and store in cache
66
  model = MusicgenForConditionalGeneration.from_pretrained(model_key)
67
  processor = AutoProcessor.from_pretrained(model_key)
68
 
@@ -73,6 +71,18 @@ def get_musicgen_model(model_key: str = "facebook/musicgen-medium"):
73
  return model, processor
74
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # ---------------------------------------------------------------------
77
  # Script Generation Function
78
  # ---------------------------------------------------------------------
@@ -85,7 +95,6 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
85
  try:
86
  text_pipeline = get_llama_pipeline(model_id, token)
87
 
88
- # System prompt with clear structure instructions
89
  system_prompt = (
90
  "You are an expert radio imaging producer specializing in sound design and music. "
91
  f"Based on the user's concept and the selected duration of {duration} seconds, produce the following: "
@@ -93,10 +102,8 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
93
  "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'.\n"
94
  "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'."
95
  )
96
-
97
  combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"
98
 
99
- # Use inference mode for efficient forward passes
100
  with torch.inference_mode():
101
  result = text_pipeline(
102
  combined_prompt,
@@ -105,38 +112,37 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
105
  temperature=0.8
106
  )
107
 
108
- # LLaMA pipeline returns a list of dicts with "generated_text"
109
  generated_text = result[0]["generated_text"]
110
-
111
- # Basic parsing to isolate everything after "Output:"
112
- # (in case the model repeated your system prompt).
113
  if "Output:" in generated_text:
114
  generated_text = generated_text.split("Output:")[-1].strip()
115
 
116
- # Extract sections based on known prefixes
117
  voice_script = "No voice-over script found."
118
  sound_design = "No sound design suggestions found."
119
  music_suggestions = "No music suggestions found."
120
 
 
121
  if "Voice-Over Script:" in generated_text:
122
  parts = generated_text.split("Voice-Over Script:")
123
- if len(parts) > 1:
124
- # Everything after "Voice-Over Script:" up until next prefix
125
- voice_script_part = parts[1]
126
- voice_script = voice_script_part.split("Sound Design Suggestions:")[0].strip() \
127
- if "Sound Design Suggestions:" in voice_script_part else voice_script_part.strip()
128
 
 
129
  if "Sound Design Suggestions:" in generated_text:
130
  parts = generated_text.split("Sound Design Suggestions:")
131
- if len(parts) > 1:
132
- sound_design_part = parts[1]
133
- sound_design = sound_design_part.split("Music Suggestions:")[0].strip() \
134
- if "Music Suggestions:" in sound_design_part else sound_design_part.strip()
 
135
 
 
136
  if "Music Suggestions:" in generated_text:
137
  parts = generated_text.split("Music Suggestions:")
138
- if len(parts) > 1:
139
- music_suggestions = parts[1].strip()
140
 
141
  return voice_script, sound_design, music_suggestions
142
 
@@ -145,46 +151,55 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
145
 
146
 
147
  # ---------------------------------------------------------------------
148
- # Voice-Over Generation Function (Inactive)
149
  # ---------------------------------------------------------------------
150
  @spaces.GPU(duration=100)
151
- def generate_voice(script: str, speaker: str = "default"):
152
  """
153
- Placeholder for future voice-over generation functionality.
 
154
  """
155
  try:
156
- return "Voice-over generation is currently inactive."
 
 
 
 
 
 
 
 
 
157
  except Exception as e:
158
- return f"Error: {e}"
159
 
160
 
161
  # ---------------------------------------------------------------------
162
- # Music Generation Function
163
  # ---------------------------------------------------------------------
164
  @spaces.GPU(duration=100)
165
  def generate_music(prompt: str, audio_length: int):
166
  """
167
- Generates music from the 'facebook/musicgen-medium' model based on the prompt.
168
  Returns the file path to the generated .wav file.
169
  """
170
  try:
171
- model_key = "facebook/musicgen-medium"
 
 
 
172
  musicgen_model, musicgen_processor = get_musicgen_model(model_key)
173
 
174
  device = "cuda" if torch.cuda.is_available() else "cpu"
175
- # Prepare input
176
  inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
177
 
178
- # Generate music within inference mode
179
  with torch.inference_mode():
180
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
181
 
182
  audio_data = outputs[0, 0].cpu().numpy()
183
- # Normalize audio to int16 format
184
  normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
185
 
186
- # Save generated music to a temp file
187
- output_path = f"{tempfile.gettempdir()}/musicgen_medium_generated_music.wav"
188
  write(output_path, 44100, normalized_audio)
189
 
190
  return output_path
@@ -194,16 +209,46 @@ def generate_music(prompt: str, audio_length: int):
194
 
195
 
196
  # ---------------------------------------------------------------------
197
- # Audio Blending Function (Inactive)
198
  # ---------------------------------------------------------------------
199
- def blend_audio(voice_path: str, music_path: str, ducking: bool):
 
200
  """
201
- Placeholder for future audio blending functionality with optional ducking.
 
 
202
  """
203
  try:
204
- return "Audio blending functionality is currently inactive."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  except Exception as e:
206
- return f"Error: {e}"
207
 
208
 
209
  # ---------------------------------------------------------------------
@@ -211,9 +256,15 @@ def blend_audio(voice_path: str, music_path: str, ducking: bool):
211
  # ---------------------------------------------------------------------
212
  with gr.Blocks() as demo:
213
  gr.Markdown("""
214
- # 🎧 AI Promo Studio 🚀
215
- Welcome to **AI Promo Studio**, your one-stop solution for creating stunning and professional radio promos with ease!
216
- Whether you're a sound designer, radio producer, or content creator, our AI-driven tools, powered by advanced LLM Llama models, empower you to bring your vision to life in just a few steps.
 
 
 
 
 
 
217
  """)
218
 
219
  with gr.Tabs():
@@ -249,24 +300,39 @@ with gr.Blocks() as demo:
249
  outputs=[script_output, sound_design_output, music_suggestion_output],
250
  )
251
 
252
- # Step 2: Generate Voice (Inactive)
253
  with gr.Tab("Step 2: Generate Voice"):
254
- gr.Markdown("""
255
- **Note:** Voice-over generation is currently inactive.
256
- This feature will be available in future updates!
257
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
- # Step 3: Generate Music
260
  with gr.Tab("Step 3: Generate Music"):
261
- with gr.Row():
262
- audio_length = gr.Slider(
263
- label="Music Length (tokens)",
264
- minimum=128,
265
- maximum=1024,
266
- step=64,
267
- value=512,
268
- info="Increase tokens for longer audio, but be mindful of inference time."
269
- )
270
  generate_music_button = gr.Button("Generate Music")
271
  music_output = gr.Audio(label="Generated Music (WAV)", type="filepath")
272
 
@@ -276,14 +342,27 @@ with gr.Blocks() as demo:
276
  outputs=[music_output],
277
  )
278
 
279
- # Step 4: Blend Audio (Inactive)
280
  with gr.Tab("Step 4: Blend Audio"):
281
- gr.Markdown("""
282
- **Note:** Audio blending functionality is currently inactive.
283
- This feature will be available in future updates!
284
- """)
 
 
 
 
 
 
 
 
 
 
 
 
 
285
 
286
- # Footer / Credits
287
  gr.Markdown("""
288
  <hr>
289
  <p style="text-align: center; font-size: 0.9em;">
@@ -298,5 +377,4 @@ with gr.Blocks() as demo:
298
  </a>
299
  """)
300
 
301
- # Launch the Gradio app
302
  demo.launch(debug=True)
 
13
  from dotenv import load_dotenv
14
  import tempfile
15
  import spaces
16
+
17
+ # Coqui TTS
18
  from TTS.api import TTS
 
19
 
20
  # ---------------------------------------------------------------------
21
  # Load Environment Variables
22
  # ---------------------------------------------------------------------
23
  load_dotenv()
24
+ HF_TOKEN = os.getenv("HF_TOKEN") # Adjust if needed
25
 
26
  # ---------------------------------------------------------------------
27
  # Global Model Caches
28
  # ---------------------------------------------------------------------
 
 
29
  LLAMA_PIPELINES = {}
30
  MUSICGEN_MODELS = {}
31
+ TTS_MODELS = {}
32
 
33
  # ---------------------------------------------------------------------
34
  # Helper Functions
 
36
  def get_llama_pipeline(model_id: str, token: str):
37
  """
38
  Returns a cached LLaMA pipeline if available; otherwise, loads it.
 
39
  """
40
  if model_id in LLAMA_PIPELINES:
41
  return LLAMA_PIPELINES[model_id]
42
 
 
43
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
44
  model = AutoModelForCausalLM.from_pretrained(
45
  model_id,
 
53
  return text_pipeline
54
 
55
 
56
+ def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
57
  """
58
  Returns a cached MusicGen model if available; otherwise, loads it.
59
+ Uses the 'large' variant for higher quality outputs.
60
  """
61
  if model_key in MUSICGEN_MODELS:
62
  return MUSICGEN_MODELS[model_key]
63
 
 
64
  model = MusicgenForConditionalGeneration.from_pretrained(model_key)
65
  processor = AutoProcessor.from_pretrained(model_key)
66
 
 
71
  return model, processor
72
 
73
 
74
+ def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
75
+ """
76
+ Returns a cached TTS model if available; otherwise, loads it.
77
+ """
78
+ if model_name in TTS_MODELS:
79
+ return TTS_MODELS[model_name]
80
+
81
+ tts_model = TTS(model_name)
82
+ TTS_MODELS[model_name] = tts_model
83
+ return tts_model
84
+
85
+
86
  # ---------------------------------------------------------------------
87
  # Script Generation Function
88
  # ---------------------------------------------------------------------
 
95
  try:
96
  text_pipeline = get_llama_pipeline(model_id, token)
97
 
 
98
  system_prompt = (
99
  "You are an expert radio imaging producer specializing in sound design and music. "
100
  f"Based on the user's concept and the selected duration of {duration} seconds, produce the following: "
 
102
  "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'.\n"
103
  "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'."
104
  )
 
105
  combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"
106
 
 
107
  with torch.inference_mode():
108
  result = text_pipeline(
109
  combined_prompt,
 
112
  temperature=0.8
113
  )
114
 
 
115
  generated_text = result[0]["generated_text"]
 
 
 
116
  if "Output:" in generated_text:
117
  generated_text = generated_text.split("Output:")[-1].strip()
118
 
119
+ # Default placeholders
120
  voice_script = "No voice-over script found."
121
  sound_design = "No sound design suggestions found."
122
  music_suggestions = "No music suggestions found."
123
 
124
+ # Voice-Over Script
125
  if "Voice-Over Script:" in generated_text:
126
  parts = generated_text.split("Voice-Over Script:")
127
+ voice_script_part = parts[1]
128
+ if "Sound Design Suggestions:" in voice_script_part:
129
+ voice_script = voice_script_part.split("Sound Design Suggestions:")[0].strip()
130
+ else:
131
+ voice_script = voice_script_part.strip()
132
 
133
+ # Sound Design
134
  if "Sound Design Suggestions:" in generated_text:
135
  parts = generated_text.split("Sound Design Suggestions:")
136
+ sound_design_part = parts[1]
137
+ if "Music Suggestions:" in sound_design_part:
138
+ sound_design = sound_design_part.split("Music Suggestions:")[0].strip()
139
+ else:
140
+ sound_design = sound_design_part.strip()
141
 
142
+ # Music Suggestions
143
  if "Music Suggestions:" in generated_text:
144
  parts = generated_text.split("Music Suggestions:")
145
+ music_suggestions = parts[1].strip()
 
146
 
147
  return voice_script, sound_design, music_suggestions
148
 
 
151
 
152
 
153
  # ---------------------------------------------------------------------
154
+ # Voice-Over Generation Function
155
  # ---------------------------------------------------------------------
156
  @spaces.GPU(duration=100)
157
+ def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
158
  """
159
+ Generates a voice-over from the provided script using the Coqui TTS model.
160
+ Returns the file path to the generated .wav file.
161
  """
162
  try:
163
+ if not script.strip():
164
+ return "Error: No script provided."
165
+
166
+ tts_model = get_tts_model(tts_model_name)
167
+
168
+ # Generate and save voice
169
+ output_path = os.path.join(tempfile.gettempdir(), "voice_over.wav")
170
+ tts_model.tts_to_file(text=script, file_path=output_path)
171
+ return output_path
172
+
173
  except Exception as e:
174
+ return f"Error generating voice: {e}"
175
 
176
 
177
  # ---------------------------------------------------------------------
178
+ # Music Generation Function (Using facebook/musicgen-large)
179
  # ---------------------------------------------------------------------
180
  @spaces.GPU(duration=100)
181
  def generate_music(prompt: str, audio_length: int):
182
  """
183
+ Generates music from the 'facebook/musicgen-large' model based on the prompt.
184
  Returns the file path to the generated .wav file.
185
  """
186
  try:
187
+ if not prompt.strip():
188
+ return "Error: No music suggestion provided."
189
+
190
+ model_key = "facebook/musicgen-large"
191
  musicgen_model, musicgen_processor = get_musicgen_model(model_key)
192
 
193
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
194
  inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
195
 
 
196
  with torch.inference_mode():
197
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
198
 
199
  audio_data = outputs[0, 0].cpu().numpy()
 
200
  normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
201
 
202
+ output_path = f"{tempfile.gettempdir()}/musicgen_large_generated_music.wav"
 
203
  write(output_path, 44100, normalized_audio)
204
 
205
  return output_path
 
209
 
210
 
211
  # ---------------------------------------------------------------------
212
+ # Audio Blending Function with Ducking
213
  # ---------------------------------------------------------------------
214
+ @spaces.GPU(duration=100)
215
+ def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int = 10):
216
  """
217
+ Blends two audio files (voice and music). If ducking=True,
218
+ the music is attenuated by 'duck_level' dB while the voice is playing.
219
+ Returns the file path to the blended .wav file.
220
  """
221
  try:
222
+ if not os.path.isfile(voice_path) or not os.path.isfile(music_path):
223
+ return "Error: Missing audio files for blending."
224
+
225
+ voice = AudioSegment.from_wav(voice_path)
226
+ music = AudioSegment.from_wav(music_path)
227
+
228
+ # If the voice is longer than the music, extend music with silence
229
+ if len(voice) > len(music):
230
+ extension = AudioSegment.silent(duration=(len(voice) - len(music)))
231
+ music = music + extension
232
+
233
+ if ducking:
234
+ # Step 1: Reduce music by `duck_level` dB for the portion matching the voice duration
235
+ ducked_music_part = music[:len(voice)] - duck_level
236
+ # Overlay voice on top of the ducked music portion
237
+ voice_overlaid = ducked_music_part.overlay(voice)
238
+
239
+ # Step 2: Keep the rest of the music as-is
240
+ remainder = music[len(voice):]
241
+ final_audio = voice_overlaid + remainder
242
+ else:
243
+ # No ducking, just overlay
244
+ final_audio = music.overlay(voice)
245
+
246
+ output_path = os.path.join(tempfile.gettempdir(), "blended_output.wav")
247
+ final_audio.export(output_path, format="wav")
248
+ return output_path
249
+
250
  except Exception as e:
251
+ return f"Error blending audio: {e}"
252
 
253
 
254
  # ---------------------------------------------------------------------
 
256
  # ---------------------------------------------------------------------
257
  with gr.Blocks() as demo:
258
  gr.Markdown("""
259
+ # 🎧 AI Promo Studio with MusicGen Large, Voice Over & Audio Blending 🚀
260
+ Welcome to **AI Promo Studio**!
261
+ This pipeline uses **facebook/musicgen-large** for high-quality background music (more resource-intensive).
262
+
263
+ **Workflow**:
264
+ 1. **Generate Script** (via LLaMA)
265
+ 2. **Generate Voice-Over** (via Coqui TTS)
266
+ 3. **Generate Music** (via MusicGen Large)
267
+ 4. **Blend** (Voice + Music) with optional ducking
268
  """)
269
 
270
  with gr.Tabs():
 
300
  outputs=[script_output, sound_design_output, music_suggestion_output],
301
  )
302
 
303
+ # Step 2: Generate Voice
304
  with gr.Tab("Step 2: Generate Voice"):
305
+ gr.Markdown("Generate the voice-over using a Coqui TTS model.")
306
+ selected_tts_model = gr.Dropdown(
307
+ label="TTS Model",
308
+ choices=[
309
+ "tts_models/en/ljspeech/tacotron2-DDC",
310
+ "tts_models/en/ljspeech/vits",
311
+ "tts_models/en/sam/tacotron-DDC",
312
+ ],
313
+ value="tts_models/en/ljspeech/tacotron2-DDC",
314
+ multiselect=False
315
+ )
316
+ generate_voice_button = gr.Button("Generate Voice-Over")
317
+ voice_audio_output = gr.Audio(label="Voice-Over (WAV)", type="filepath")
318
+
319
+ generate_voice_button.click(
320
+ fn=lambda script, tts_model: generate_voice(script, tts_model),
321
+ inputs=[script_output, selected_tts_model],
322
+ outputs=voice_audio_output,
323
+ )
324
 
325
+ # Step 3: Generate Music (MusicGen Large)
326
  with gr.Tab("Step 3: Generate Music"):
327
+ gr.Markdown("Generate a music track with the **MusicGen Large** model.")
328
+ audio_length = gr.Slider(
329
+ label="Music Length (tokens)",
330
+ minimum=128,
331
+ maximum=1024,
332
+ step=64,
333
+ value=512,
334
+ info="Increase tokens for longer audio, but be mindful of inference time."
335
+ )
336
  generate_music_button = gr.Button("Generate Music")
337
  music_output = gr.Audio(label="Generated Music (WAV)", type="filepath")
338
 
 
342
  outputs=[music_output],
343
  )
344
 
345
+ # Step 4: Blend Audio
346
  with gr.Tab("Step 4: Blend Audio"):
347
+ gr.Markdown("Combine voice-over and music, optionally applying ducking.")
348
+ ducking_checkbox = gr.Checkbox(label="Enable Ducking?", value=True)
349
+ duck_level_slider = gr.Slider(
350
+ label="Ducking Level (dB attenuation)",
351
+ minimum=0,
352
+ maximum=20,
353
+ step=1,
354
+ value=10
355
+ )
356
+ blend_button = gr.Button("Blend Voice + Music")
357
+ blended_output = gr.Audio(label="Final Blended Output (WAV)", type="filepath")
358
+
359
+ blend_button.click(
360
+ fn=blend_audio,
361
+ inputs=[voice_audio_output, music_output, ducking_checkbox, duck_level_slider],
362
+ outputs=blended_output
363
+ )
364
 
365
+ # Footer
366
  gr.Markdown("""
367
  <hr>
368
  <p style="text-align: center; font-size: 0.9em;">
 
377
  </a>
378
  """)
379
 
 
380
  demo.launch(debug=True)