owiedotch committed on
Commit
0b50165
1 Parent(s): e4056a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -18,6 +18,7 @@ except Exception as e:
18
  torch_device = torch.device("cpu")
19
 
20
  # Load the AGC model
 
21
  def load_agc_model():
22
  return AGC.from_pretrained("Audiogen/agc-continuous").to(torch_device)
23
 
@@ -29,9 +30,16 @@ def encode_audio(audio_file_path):
29
  # Load the audio file
30
  waveform, sample_rate = torchaudio.load(audio_file_path)
31
 
32
- # Convert to stereo if necessary
33
- if waveform.size(0) == 1:
34
- waveform = waveform.repeat(2, 1)
 
 
 
 
 
 
 
35
 
36
  # Encode the audio
37
  audio = waveform.unsqueeze(0).to(torch_device)
@@ -67,7 +75,7 @@ def decode_audio(encoded_file_path):
67
 
68
  # Save to a temporary WAV file
69
  temp_wav_path = tempfile.mktemp(suffix=".wav")
70
- torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), sample_rate)
71
  return temp_wav_path
72
 
73
  except Exception as e:
@@ -84,23 +92,23 @@ def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]:
84
  z = torch.from_numpy(z_numpy).to(torch_device)
85
 
86
  # Decode the audio in chunks
87
- chunk_size = 16000 # 1 second of audio at 16kHz
88
- sample_rate = 16000 # AGC model's output sample rate
89
  with torch.no_grad():
90
  for i in range(0, z.shape[2], chunk_size):
91
  z_chunk = z[:, :, i:i+chunk_size]
92
  audio_chunk = agc.decode(z_chunk)
93
- # Convert to numpy array and transpose
94
- audio_data = audio_chunk.squeeze(0).cpu().numpy().T
95
  yield (sample_rate, audio_data)
96
 
97
  except Exception as e:
98
  print(f"Streaming decoding error: {e}")
99
- yield (sample_rate, np.zeros((2, chunk_size), dtype=np.float32)) # Return silence in case of error
100
 
101
  # Gradio Interface
102
  with gr.Blocks() as demo:
103
- gr.Markdown("## Audio Compression with AGC (GPU/CPU)")
104
 
105
  with gr.Tab("Encode"):
106
  input_audio = gr.Audio(label="Input Audio", type="filepath")
 
18
  torch_device = torch.device("cpu")
19
 
20
  # Load the AGC model
21
+ @spaces.GPU(duration=180)
22
  def load_agc_model():
23
  return AGC.from_pretrained("Audiogen/agc-continuous").to(torch_device)
24
 
 
30
  # Load the audio file
31
  waveform, sample_rate = torchaudio.load(audio_file_path)
32
 
33
+ # Resample to 32kHz if necessary
34
+ if sample_rate != 32000:
35
+ resampler = torchaudio.transforms.Resample(sample_rate, 32000)
36
+ waveform = resampler(waveform)
37
+
38
+ # Convert to 32 channels if necessary
39
+ if waveform.size(0) < 32:
40
+ waveform = waveform.repeat(32, 1)[:32, :]
41
+ elif waveform.size(0) > 32:
42
+ waveform = waveform[:32, :]
43
 
44
  # Encode the audio
45
  audio = waveform.unsqueeze(0).to(torch_device)
 
75
 
76
  # Save to a temporary WAV file
77
  temp_wav_path = tempfile.mktemp(suffix=".wav")
78
+ torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), 32000)
79
  return temp_wav_path
80
 
81
  except Exception as e:
 
92
  z = torch.from_numpy(z_numpy).to(torch_device)
93
 
94
  # Decode the audio in chunks
95
+ chunk_size = 32000 # 1 second of audio at 32kHz
96
+ sample_rate = 32000 # AGC model's output sample rate
97
  with torch.no_grad():
98
  for i in range(0, z.shape[2], chunk_size):
99
  z_chunk = z[:, :, i:i+chunk_size]
100
  audio_chunk = agc.decode(z_chunk)
101
+ # Convert to numpy array (32 channels)
102
+ audio_data = audio_chunk.squeeze(0).cpu().numpy()
103
  yield (sample_rate, audio_data)
104
 
105
  except Exception as e:
106
  print(f"Streaming decoding error: {e}")
107
+ yield (sample_rate, np.zeros((32, chunk_size), dtype=np.float32)) # Return silence in case of error
108
 
109
  # Gradio Interface
110
  with gr.Blocks() as demo:
111
+ gr.Markdown("## Audio Compression with AGC (GPU/CPU) - 32 channels, 32kHz")
112
 
113
  with gr.Tab("Encode"):
114
  input_audio = gr.Audio(label="Input Audio", type="filepath")