Bils committed
Commit db8ba25 · verified · 1 Parent(s): 07c07fa

Update app.py

Files changed (1)
  1. app.py +38 -74
app.py CHANGED
@@ -6,7 +6,7 @@ from transformers import (
     AutoModelForCausalLM,
     pipeline,
     AutoProcessor,
-    MusicgenForConditionalGeneration
+    MusicgenForConditionalGeneration,
 )
 from scipy.io.wavfile import write
 import tempfile
@@ -17,81 +17,44 @@ import spaces  # Assumes Hugging Face Spaces library supports `@spaces.GPU`
 load_dotenv()
 hf_token = os.getenv("HF_TOKEN")
 
-# Globals for lazy loading
-llama_pipeline = None
-musicgen_model = None
-musicgen_processor = None
 
 # ---------------------------------------------------------------------
-# Load Llama 3 Model with Zero GPU (Lazy Loading) - Smaller Model
+# Load Llama 3 Pipeline with Zero GPU (Encapsulated)
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=300)  # Adjust GPU allocation duration
-def load_llama_pipeline_zero_gpu(model_id: str, token: str):
-    global llama_pipeline
-    if llama_pipeline is None:
-        try:
-            print("Starting model loading...")
-            tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
-            print("Tokenizer loaded.")
-            model = AutoModelForCausalLM.from_pretrained(
-                model_id,
-                use_auth_token=token,
-                torch_dtype=torch.float16,
-                device_map="auto",  # Automatically handles GPU allocation
-                trust_remote_code=True
-            )
-            print("Model loaded. Initializing pipeline...")
-            llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
-            print("Pipeline initialized successfully.")
-        except Exception as e:
-            print(f"Error loading Llama pipeline: {e}")
-            return str(e)
-    return llama_pipeline
-
-# ---------------------------------------------------------------------
-# Generate Radio Script
-# ---------------------------------------------------------------------
-def generate_script(user_input: str, pipeline_llama):
+def generate_script(user_prompt: str, model_id: str, token: str):
     try:
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            use_auth_token=token,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
+
         system_prompt = (
             "You are a top-tier radio imaging producer using Llama 3. "
             "Take the user's concept and craft a short, creative promo script."
         )
-        combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
-        result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
-        return result[0]['generated_text'].split("Refined script:")[-1].strip()
+        combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nRefined script:"
+        result = llama_pipeline(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
+        return result[0]["generated_text"].split("Refined script:")[-1].strip()
     except Exception as e:
         return f"Error generating script: {e}"
 
-# ---------------------------------------------------------------------
-# Load MusicGen Model (Lazy Loading)
-# ---------------------------------------------------------------------
-@spaces.GPU(duration=300)
-def load_musicgen_model():
-    global musicgen_model, musicgen_processor
-    if musicgen_model is None or musicgen_processor is None:
-        try:
-            print("Loading MusicGen model...")
-            musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
-            musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
-            print("MusicGen model loaded successfully.")
-        except Exception as e:
-            print(f"Error loading MusicGen model: {e}")
-            return None, str(e)
-    return musicgen_model, musicgen_processor
 
 # ---------------------------------------------------------------------
-# Generate Audio
+# Load MusicGen Model (Encapsulated)
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=300)
 def generate_audio(prompt: str, audio_length: int):
-    global musicgen_model, musicgen_processor
-    if musicgen_model is None or musicgen_processor is None:
-        musicgen_model, musicgen_processor = load_musicgen_model()
-    if isinstance(musicgen_model, str):
-        return musicgen_model
     try:
-        musicgen_model.to("cuda")  # Move the model to GPU
+        musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+        musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+
+        musicgen_model.to("cuda")
         inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
         outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
         musicgen_model.to("cpu")  # Return the model to CPU
@@ -106,21 +69,17 @@ def generate_audio(prompt: str, audio_length: int):
     except Exception as e:
         return f"Error generating audio: {e}"
 
+
 # ---------------------------------------------------------------------
 # Gradio Interface
 # ---------------------------------------------------------------------
-def radio_imaging_app(user_prompt, llama_model_id, audio_length):
-    # Load Llama 3 Pipeline with Zero GPU
-    pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
-    if isinstance(pipeline_llama, str):
-        return pipeline_llama, None
+def interface_generate_script(user_prompt, llama_model_id):
+    return generate_script(user_prompt, llama_model_id, hf_token)
 
-    # Generate Script
-    script = generate_script(user_prompt, pipeline_llama)
 
-    # Generate Audio
-    audio_data = generate_audio(script, audio_length)
-    return script, audio_data
+def interface_generate_audio(script, audio_length):
+    return generate_audio(script, audio_length)
+
 
 # ---------------------------------------------------------------------
 # Interface
@@ -129,8 +88,13 @@ with gr.Blocks() as demo:
     gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)")
 
     with gr.Row():
-        user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show.")
-        llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-8B-Instruct")  # Smaller Model
+        user_prompt = gr.Textbox(
+            label="Enter your promo idea",
+            placeholder="E.g., A 15-second hype jingle for a morning talk show.",
+        )
+        llama_model_id = gr.Textbox(
+            label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-8B-Instruct"
+        )
         audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
 
     generate_script_button = gr.Button("Generate Script")
@@ -139,15 +103,15 @@ with gr.Blocks() as demo:
     audio_output = gr.Audio(label="Generated Audio", type="filepath")
 
     generate_script_button.click(
-        fn=lambda user_prompt, llama_model_id: radio_imaging_app(user_prompt, llama_model_id, None)[0],
+        fn=interface_generate_script,
         inputs=[user_prompt, llama_model_id],
-        outputs=script_output
+        outputs=script_output,
    )
 
     generate_audio_button.click(
-        fn=lambda script_output, audio_length: generate_audio(script_output, audio_length),
+        fn=interface_generate_audio,
         inputs=[script_output, audio_length],
-        outputs=audio_output
+        outputs=audio_output,
     )
 
 # ---------------------------------------------------------------------
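Context for the refactor: on ZeroGPU Spaces a device is attached only while a `@spaces.GPU`-decorated function runs, so this commit moves model loading and CUDA use inside `generate_script` and `generate_audio` rather than caching pipelines in module-level globals. A minimal sketch of that pattern, assuming the `spaces` helper available on Hugging Face Spaces ("distilgpt2" and `run_generation` are illustrative stand-ins, not part of this commit):

import spaces
from transformers import pipeline

@spaces.GPU(duration=60)  # GPU is allocated only for the duration of this call
def run_generation(prompt: str) -> str:
    # Build the pipeline inside the decorated function, as app.py now does;
    # nothing touches CUDA at import time.
    pipe = pipeline("text-generation", model="distilgpt2", device="cuda")
    return pipe(prompt, max_new_tokens=32)[0]["generated_text"]

The trade-off is reloading weights on every call in exchange for a simpler lifecycle on shared Zero-GPU hardware.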