Bils commited on
Commit
07c07fa
·
verified ·
1 Parent(s): 5080bd7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -48
app.py CHANGED
@@ -6,7 +6,7 @@ from transformers import (
6
  AutoModelForCausalLM,
7
  pipeline,
8
  AutoProcessor,
9
- MusicgenForConditionalGeneration,
10
  )
11
  from scipy.io.wavfile import write
12
  import tempfile
@@ -23,23 +23,28 @@ musicgen_model = None
23
  musicgen_processor = None
24
 
25
  # ---------------------------------------------------------------------
26
- # Load Llama 3 Model with Zero GPU (Lazy Loading)
27
  # ---------------------------------------------------------------------
28
- @spaces.GPU(duration=300)
29
  def load_llama_pipeline_zero_gpu(model_id: str, token: str):
30
  global llama_pipeline
31
  if llama_pipeline is None:
32
  try:
 
33
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
 
34
  model = AutoModelForCausalLM.from_pretrained(
35
  model_id,
36
  use_auth_token=token,
37
  torch_dtype=torch.float16,
38
  device_map="auto", # Automatically handles GPU allocation
39
- trust_remote_code=True,
40
  )
 
41
  llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
 
42
  except Exception as e:
 
43
  return str(e)
44
  return llama_pipeline
45
 
@@ -54,7 +59,7 @@ def generate_script(user_input: str, pipeline_llama):
54
  )
55
  combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
56
  result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
57
- return result[0]["generated_text"].split("Refined script:")[-1].strip()
58
  except Exception as e:
59
  return f"Error generating script: {e}"
60
 
@@ -66,9 +71,12 @@ def load_musicgen_model():
66
  global musicgen_model, musicgen_processor
67
  if musicgen_model is None or musicgen_processor is None:
68
  try:
 
69
  musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
70
  musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
 
71
  except Exception as e:
 
72
  return None, str(e)
73
  return musicgen_model, musicgen_processor
74
 
@@ -101,19 +109,18 @@ def generate_audio(prompt: str, audio_length: int):
101
  # ---------------------------------------------------------------------
102
  # Gradio Interface
103
  # ---------------------------------------------------------------------
104
- def radio_imaging_app(user_prompt, llama_model_id):
105
  # Load Llama 3 Pipeline with Zero GPU
106
  pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
107
  if isinstance(pipeline_llama, str):
108
- return pipeline_llama
109
 
110
  # Generate Script
111
- return generate_script(user_prompt, pipeline_llama)
112
-
113
-
114
- def generate_audio_from_script(script, audio_length):
115
- return generate_audio(script, audio_length)
116
 
 
 
 
117
 
118
  # ---------------------------------------------------------------------
119
  # Interface
@@ -121,42 +128,27 @@ def generate_audio_from_script(script, audio_length):
121
  with gr.Blocks() as demo:
122
  gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)")
123
 
124
- with gr.Tab("Step 1: Generate Promo Script"):
125
- with gr.Row():
126
- user_prompt = gr.Textbox(
127
- label="Enter Your Promo Idea",
128
- placeholder="E.g., A 15-second hype jingle for a morning talk show.",
129
- )
130
- llama_model_id = gr.Textbox(
131
- label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-70B"
132
- )
133
-
134
- generate_script_button = gr.Button("Generate Script")
135
- script_output = gr.Textbox(label="Generated Promo Script", interactive=False)
136
-
137
- generate_script_button.click(
138
- fn=radio_imaging_app,
139
- inputs=[user_prompt, llama_model_id],
140
- outputs=script_output,
141
- )
142
-
143
- with gr.Tab("Step 2: Generate Audio"):
144
- with gr.Row():
145
- audio_length = gr.Slider(
146
- label="Audio Length (tokens)",
147
- minimum=128,
148
- maximum=1024,
149
- step=64,
150
- value=512,
151
- )
152
- generate_audio_button = gr.Button("Generate Audio")
153
- audio_output = gr.Audio(label="Generated Audio", type="filepath")
154
-
155
- generate_audio_button.click(
156
- fn=generate_audio_from_script,
157
- inputs=[script_output, audio_length],
158
- outputs=audio_output,
159
- )
160
 
161
  # ---------------------------------------------------------------------
162
  # Launch App
 
6
  AutoModelForCausalLM,
7
  pipeline,
8
  AutoProcessor,
9
+ MusicgenForConditionalGeneration
10
  )
11
  from scipy.io.wavfile import write
12
  import tempfile
 
23
  musicgen_processor = None
24
 
25
  # ---------------------------------------------------------------------
26
+ # Load Llama 3 Model with Zero GPU (Lazy Loading) - Smaller Model
27
  # ---------------------------------------------------------------------
28
+ @spaces.GPU(duration=300) # Adjust GPU allocation duration
29
  def load_llama_pipeline_zero_gpu(model_id: str, token: str):
30
  global llama_pipeline
31
  if llama_pipeline is None:
32
  try:
33
+ print("Starting model loading...")
34
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
35
+ print("Tokenizer loaded.")
36
  model = AutoModelForCausalLM.from_pretrained(
37
  model_id,
38
  use_auth_token=token,
39
  torch_dtype=torch.float16,
40
  device_map="auto", # Automatically handles GPU allocation
41
+ trust_remote_code=True
42
  )
43
+ print("Model loaded. Initializing pipeline...")
44
  llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
45
+ print("Pipeline initialized successfully.")
46
  except Exception as e:
47
+ print(f"Error loading Llama pipeline: {e}")
48
  return str(e)
49
  return llama_pipeline
50
 
 
59
  )
60
  combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
61
  result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
62
+ return result[0]['generated_text'].split("Refined script:")[-1].strip()
63
  except Exception as e:
64
  return f"Error generating script: {e}"
65
 
 
71
  global musicgen_model, musicgen_processor
72
  if musicgen_model is None or musicgen_processor is None:
73
  try:
74
+ print("Loading MusicGen model...")
75
  musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
76
  musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
77
+ print("MusicGen model loaded successfully.")
78
  except Exception as e:
79
+ print(f"Error loading MusicGen model: {e}")
80
  return None, str(e)
81
  return musicgen_model, musicgen_processor
82
 
 
109
  # ---------------------------------------------------------------------
110
  # Gradio Interface
111
  # ---------------------------------------------------------------------
112
+ def radio_imaging_app(user_prompt, llama_model_id, audio_length):
113
  # Load Llama 3 Pipeline with Zero GPU
114
  pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
115
  if isinstance(pipeline_llama, str):
116
+ return pipeline_llama, None
117
 
118
  # Generate Script
119
+ script = generate_script(user_prompt, pipeline_llama)
 
 
 
 
120
 
121
+ # Generate Audio
122
+ audio_data = generate_audio(script, audio_length)
123
+ return script, audio_data
124
 
125
  # ---------------------------------------------------------------------
126
  # Interface
 
128
  with gr.Blocks() as demo:
129
  gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)")
130
 
131
+ with gr.Row():
132
+ user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show.")
133
+ llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-8B-Instruct") # Smaller Model
134
+ audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
135
+
136
+ generate_script_button = gr.Button("Generate Script")
137
+ generate_audio_button = gr.Button("Generate Audio")
138
+ script_output = gr.Textbox(label="Generated Script")
139
+ audio_output = gr.Audio(label="Generated Audio", type="filepath")
140
+
141
+ generate_script_button.click(
142
+ fn=lambda user_prompt, llama_model_id: radio_imaging_app(user_prompt, llama_model_id, None)[0],
143
+ inputs=[user_prompt, llama_model_id],
144
+ outputs=script_output
145
+ )
146
+
147
+ generate_audio_button.click(
148
+ fn=lambda script_output, audio_length: generate_audio(script_output, audio_length),
149
+ inputs=[script_output, audio_length],
150
+ outputs=audio_output
151
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  # ---------------------------------------------------------------------
154
  # Launch App