Bils committed
Commit 1c1b50f · verified · 1 Parent(s): 53f90b7

Update app.py

Files changed (1):
  1. app.py +32 -16
app.py CHANGED
@@ -1,17 +1,18 @@
 import gradio as gr
 import os
 import torch
+import time
 from transformers import (
-    AutoTokenizer,
-    AutoModelForCausalLM,
+    AutoTokenizer,
+    AutoModelForCausalLM,
     pipeline,
-    AutoProcessor,
-    MusicgenForConditionalGeneration
+    AutoProcessor,
+    MusicgenForConditionalGeneration,
 )
 from scipy.io.wavfile import write
 import tempfile
 from dotenv import load_dotenv
-import spaces  # Assumes Hugging Face Spaces library supports `@spaces.GPU`
+import spaces  # Hugging Face Spaces library for ZeroGPU support
 
 # Load environment variables (e.g., Hugging Face token)
 load_dotenv()
@@ -22,10 +23,31 @@ llama_pipeline = None
 musicgen_model = None
 musicgen_processor = None
 
+# ---------------------------------------------------------------------
+# Helper: Safe Model Loader with Retry Logic
+# ---------------------------------------------------------------------
+def safe_load_model(model_id, token, retries=3, delay=5):
+    for attempt in range(retries):
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                use_auth_token=token,
+                torch_dtype=torch.float16,
+                device_map="auto",
+                trust_remote_code=True,
+                offload_folder="/tmp",  # Stream shards
+                cache_dir="/tmp"  # Cache directory for shard downloads
+            )
+            return model
+        except Exception as e:
+            print(f"Attempt {attempt + 1} failed: {e}")
+            time.sleep(delay)
+    raise RuntimeError(f"Failed to load model {model_id} after {retries} attempts")
+
 # ---------------------------------------------------------------------
 # Load Llama 3 Model with Zero GPU (Lazy Loading)
 # ---------------------------------------------------------------------
-@spaces.GPU(duration=300)  # Increased duration to 300 seconds
+@spaces.GPU(duration=600)  # Increased duration to handle large models
 def load_llama_pipeline_zero_gpu(model_id: str, token: str):
     global llama_pipeline
     if llama_pipeline is None:
@@ -33,13 +55,7 @@ def load_llama_pipeline_zero_gpu(model_id: str, token: str):
         print("Starting model loading...")
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
         print("Tokenizer loaded.")
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            use_auth_token=token,
-            torch_dtype=torch.float16,
-            device_map="auto",  # Automatically handles GPU allocation
-            trust_remote_code=True
-        )
+        model = safe_load_model(model_id, token)
         print("Model loaded. Initializing pipeline...")
         llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
         print("Pipeline initialized successfully.")
@@ -66,7 +82,7 @@ def generate_script(user_input: str, pipeline_llama):
 # ---------------------------------------------------------------------
 # Load MusicGen Model (Lazy Loading)
 # ---------------------------------------------------------------------
-@spaces.GPU(duration=300)
+@spaces.GPU(duration=600)
 def load_musicgen_model():
     global musicgen_model, musicgen_processor
     if musicgen_model is None or musicgen_processor is None:
@@ -83,7 +99,7 @@ def load_musicgen_model():
 # ---------------------------------------------------------------------
 # Generate Audio
 # ---------------------------------------------------------------------
-@spaces.GPU(duration=300)
+@spaces.GPU(duration=600)
 def generate_audio(prompt: str, audio_length: int):
     global musicgen_model, musicgen_processor
     if musicgen_model is None or musicgen_processor is None:
@@ -132,7 +148,7 @@ with gr.Blocks() as demo:
 
     with gr.Row():
         user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show.")
-        llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-70B")
+        llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-8B")  # Using a smaller model for better compatibility
         audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
 
     with gr.Row():
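
For reference, the new safe_load_model helper is an instance of a generic fixed-delay retry pattern: attempt the call, log the failure, sleep, and try again, raising only after the final attempt fails. A minimal standalone sketch of that pattern (the names retry and flaky_load are illustrative, not part of the commit):

import time

def retry(fn, retries=3, delay=5):
    """Call fn(); on failure wait `delay` seconds and retry, up to `retries` attempts."""
    last_error = None
    for attempt in range(retries):
        try:
            return fn()
        except Exception as e:  # deliberately broad, mirroring the commit's loader
            last_error = e
            print(f"Attempt {attempt + 1} failed: {e}")
            time.sleep(delay)
    raise RuntimeError(f"All {retries} attempts failed") from last_error

# Hypothetical usage: wrap any flaky call, e.g. a large model download.
# model = retry(lambda: flaky_load("meta-llama/Meta-Llama-3-8B"))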
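
The decorator bump from duration=300 to duration=600 matters because ZeroGPU attaches a GPU only for the lifetime of the decorated call, and a call that outlives its window is cut off, so slow model loads need a generous allowance. A minimal sketch of the lazy-load pattern the app pairs with the decorator (load_pipeline is a hypothetical stand-in for the app's loaders):

import spaces  # Hugging Face Spaces ZeroGPU helper

_pipeline = None  # module-level cache: load once, reuse across calls

@spaces.GPU(duration=600)  # GPU is allocated only while this function runs
def generate(prompt: str) -> str:
    global _pipeline
    if _pipeline is None:  # lazy: the first GPU-backed call pays the load cost
        _pipeline = load_pipeline()  # hypothetical loader, cf. safe_load_model above
    return _pipeline(prompt)[0]["generated_text"]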