Update app.py
Browse files
app.py
CHANGED
@@ -18,6 +18,7 @@ except Exception as e:
|
|
18 |
torch_device = torch.device("cpu")
|
19 |
|
20 |
# Load the AGC model
|
|
|
21 |
def load_agc_model():
|
22 |
return AGC.from_pretrained("Audiogen/agc-continuous").to(torch_device)
|
23 |
|
@@ -29,9 +30,16 @@ def encode_audio(audio_file_path):
|
|
29 |
# Load the audio file
|
30 |
waveform, sample_rate = torchaudio.load(audio_file_path)
|
31 |
|
32 |
-
#
|
33 |
-
if
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
# Encode the audio
|
37 |
audio = waveform.unsqueeze(0).to(torch_device)
|
@@ -67,7 +75,7 @@ def decode_audio(encoded_file_path):
|
|
67 |
|
68 |
# Save to a temporary WAV file
|
69 |
temp_wav_path = tempfile.mktemp(suffix=".wav")
|
70 |
-
torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(),
|
71 |
return temp_wav_path
|
72 |
|
73 |
except Exception as e:
|
@@ -84,23 +92,23 @@ def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]:
|
|
84 |
z = torch.from_numpy(z_numpy).to(torch_device)
|
85 |
|
86 |
# Decode the audio in chunks
|
87 |
-
chunk_size =
|
88 |
-
sample_rate =
|
89 |
with torch.no_grad():
|
90 |
for i in range(0, z.shape[2], chunk_size):
|
91 |
z_chunk = z[:, :, i:i+chunk_size]
|
92 |
audio_chunk = agc.decode(z_chunk)
|
93 |
-
# Convert to numpy array
|
94 |
-
audio_data = audio_chunk.squeeze(0).cpu().numpy()
|
95 |
yield (sample_rate, audio_data)
|
96 |
|
97 |
except Exception as e:
|
98 |
print(f"Streaming decoding error: {e}")
|
99 |
-
yield (sample_rate, np.zeros((
|
100 |
|
101 |
# Gradio Interface
|
102 |
with gr.Blocks() as demo:
|
103 |
-
gr.Markdown("## Audio Compression with AGC (GPU/CPU)")
|
104 |
|
105 |
with gr.Tab("Encode"):
|
106 |
input_audio = gr.Audio(label="Input Audio", type="filepath")
|
|
|
18 |
torch_device = torch.device("cpu")
|
19 |
|
20 |
# Load the AGC model
|
21 |
+
@spaces.GPU(duration=180)
|
22 |
def load_agc_model():
|
23 |
return AGC.from_pretrained("Audiogen/agc-continuous").to(torch_device)
|
24 |
|
|
|
30 |
# Load the audio file
|
31 |
waveform, sample_rate = torchaudio.load(audio_file_path)
|
32 |
|
33 |
+
# Resample to 32kHz if necessary
|
34 |
+
if sample_rate != 32000:
|
35 |
+
resampler = torchaudio.transforms.Resample(sample_rate, 32000)
|
36 |
+
waveform = resampler(waveform)
|
37 |
+
|
38 |
+
# Force exactly 32 rows along dim 0: tile the existing channels when there are fewer, keep the first 32 when there are more
|
39 |
+
if waveform.size(0) < 32:
|
40 |
+
waveform = waveform.repeat(32, 1)[:32, :]
|
41 |
+
elif waveform.size(0) > 32:
|
42 |
+
waveform = waveform[:32, :]
|
43 |
|
44 |
# Encode the audio
|
45 |
audio = waveform.unsqueeze(0).to(torch_device)
|
|
|
75 |
|
76 |
# Save to a temporary WAV file
|
77 |
temp_wav_path = tempfile.mktemp(suffix=".wav")
|
78 |
+
torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), 32000)
|
79 |
return temp_wav_path
|
80 |
|
81 |
except Exception as e:
|
|
|
92 |
z = torch.from_numpy(z_numpy).to(torch_device)
|
93 |
|
94 |
# Decode the audio in chunks
|
95 |
+
chunk_size = 32000 # steps along z's last axis decoded per call (1 second of audio only if z runs at 32kHz - confirm)
|
96 |
+
sample_rate = 32000 # AGC model's output sample rate
|
97 |
with torch.no_grad():
|
98 |
for i in range(0, z.shape[2], chunk_size):
|
99 |
z_chunk = z[:, :, i:i+chunk_size]
|
100 |
audio_chunk = agc.decode(z_chunk)
|
101 |
+
# Convert to numpy array (32 channels)
|
102 |
+
audio_data = audio_chunk.squeeze(0).cpu().numpy()
|
103 |
yield (sample_rate, audio_data)
|
104 |
|
105 |
except Exception as e:
|
106 |
print(f"Streaming decoding error: {e}")
|
107 |
+
yield (sample_rate, np.zeros((32, chunk_size), dtype=np.float32)) # Return silence in case of error
|
108 |
|
109 |
# Gradio Interface
|
110 |
with gr.Blocks() as demo:
|
111 |
+
gr.Markdown("## Audio Compression with AGC (GPU/CPU) - 32 channels, 32kHz")
|
112 |
|
113 |
with gr.Tab("Encode"):
|
114 |
input_audio = gr.Audio(label="Input Audio", type="filepath")
|