owiedotch commited on
Commit
dd53483
1 Parent(s): 5895437

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -5
app.py CHANGED
@@ -7,11 +7,15 @@ from encodec.utils import convert_audio
7
  from encodec.compress import compress_to_file, decompress_from_file
8
  import io
9
 
10
- # Load the Encodec model
11
- model = EncodecModel.encodec_model_48khz()
 
 
 
 
12
  model.set_target_bandwidth(6.0)
13
 
14
- @spaces.GPU
15
  def encode(audio_file_path):
16
  try:
17
  # Load and pre-process the audio waveform
@@ -22,6 +26,7 @@ def encode(audio_file_path):
22
  wav = torch.mean(wav, dim=0, keepdim=True)
23
 
24
  wav = convert_audio(wav, sr, model.sample_rate, model.channels)
 
25
 
26
  # Compress to ecdc file in memory
27
  output = io.BytesIO()
@@ -34,11 +39,11 @@ def encode(audio_file_path):
34
  gr.Warning(f"An error occurred during encoding: {e}")
35
  return None
36
 
37
- @spaces.GPU
38
  def decode(compressed_audio_file):
39
  try:
40
  # Decompress audio
41
- wav, sr = decompress_from_file(compressed_audio_file)
42
 
43
  # Convert the decoded audio to a numpy array for Gradio output
44
  decoded_audio = wav.cpu().numpy()
 
7
  from encodec.compress import compress_to_file, decompress_from_file
8
  import io
9
 
10
+ # Check for CUDA availability and set device
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ print(f"Using device: {device}")
13
+
14
+ # Load the Encodec model and move it to the selected device
15
+ model = EncodecModel.encodec_model_48khz().to(device)
16
  model.set_target_bandwidth(6.0)
17
 
18
+ @spaces.GPU # Indicate GPU usage for Spaces environment (if applicable)
19
  def encode(audio_file_path):
20
  try:
21
  # Load and pre-process the audio waveform
 
26
  wav = torch.mean(wav, dim=0, keepdim=True)
27
 
28
  wav = convert_audio(wav, sr, model.sample_rate, model.channels)
29
+ wav = wav.to(device) # Move the input audio to the selected device
30
 
31
  # Compress to ecdc file in memory
32
  output = io.BytesIO()
 
39
  gr.Warning(f"An error occurred during encoding: {e}")
40
  return None
41
 
42
+ @spaces.GPU
43
  def decode(compressed_audio_file):
44
  try:
45
  # Decompress audio
46
+ wav, sr = decompress_from_file(compressed_audio_file, device=device) # Pass the device to decompress_from_file
47
 
48
  # Convert the decoded audio to a numpy array for Gradio output
49
  decoded_audio = wav.cpu().numpy()