Bils commited on
Commit
0ef49a6
·
verified ·
1 Parent(s): 653eb14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -205,7 +205,9 @@ def generate_music(prompt: str, audio_length: int):
205
  musicgen_model, musicgen_processor = get_musicgen_model(model_key)
206
 
207
  device = "cuda" if torch.cuda.is_available() else "cpu"
208
- inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
 
 
209
 
210
  with torch.inference_mode():
211
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
 
205
  musicgen_model, musicgen_processor = get_musicgen_model(model_key)
206
 
207
  device = "cuda" if torch.cuda.is_available() else "cpu"
208
+ # Process the input and move each tensor to the proper device
209
+ inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
210
+ inputs = {k: v.to(device) for k, v in inputs.items()}
211
 
212
  with torch.inference_mode():
213
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)