Update app.py
Browse files
app.py
CHANGED
@@ -3,48 +3,46 @@ import torch
|
|
3 |
from datasets import load_dataset, Audio
|
4 |
from transformers import EncodecModel, AutoProcessor
|
5 |
import spaces
|
|
|
|
|
6 |
|
7 |
# Load the Encodec model and processor
|
8 |
model = EncodecModel.from_pretrained("facebook/encodec_48khz")
|
9 |
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
|
10 |
|
11 |
@spaces.GPU
|
12 |
-
def encode(audio_file_path):
|
13 |
try:
|
14 |
# Open the audio file
|
15 |
with open(audio_file_path, "rb") as audio_file:
|
16 |
# Load and preprocess the audio
|
17 |
audio_sample, sampling_rate = load_dataset("audiofolder", data_dir=audio_file_path, split="train")[0]["audio"]
|
18 |
-
|
19 |
|
20 |
-
#
|
21 |
-
|
22 |
-
encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
|
23 |
|
24 |
-
#
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
return {"codes": audio_codes.tolist(), "scales": audio_scales.tolist()}
|
30 |
|
31 |
except Exception as e:
|
32 |
gr.Warning(f"An error occurred during encoding: {e}")
|
33 |
return None
|
34 |
|
35 |
@spaces.GPU
|
36 |
-
def decode(
|
37 |
try:
|
38 |
-
#
|
39 |
-
|
40 |
-
audio_scales = torch.tensor(encoded_data["scales"])
|
41 |
-
|
42 |
-
# Decode the audio
|
43 |
-
with torch.no_grad():
|
44 |
-
audio_values = model.decode(audio_codes, audio_scales)[0]
|
45 |
|
|
|
|
|
|
|
46 |
# Convert the decoded audio to a numpy array for Gradio output
|
47 |
-
decoded_audio =
|
48 |
|
49 |
return decoded_audio
|
50 |
|
@@ -61,17 +59,17 @@ with gr.Blocks() as demo:
|
|
61 |
audio_input = gr.Audio(type="filepath", label="Input Audio")
|
62 |
encode_button = gr.Button("Encode", variant="primary")
|
63 |
with gr.Row():
|
64 |
-
encoded_output = gr.
|
65 |
|
66 |
encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
|
67 |
|
68 |
with gr.Tab("Decode"):
|
69 |
with gr.Row():
|
70 |
-
|
71 |
decode_button = gr.Button("Decode", variant="primary")
|
72 |
with gr.Row():
|
73 |
decoded_output = gr.Audio(label="Decompressed Audio")
|
74 |
|
75 |
-
decode_button.click(decode, inputs=
|
76 |
|
77 |
demo.queue().launch()
|
|
|
3 |
from datasets import load_dataset, Audio
|
4 |
from transformers import EncodecModel, AutoProcessor
|
5 |
import spaces
|
6 |
+
from encodec import compress, decompress
|
7 |
+
import io
|
8 |
|
9 |
# Load the Encodec model and processor
|
10 |
model = EncodecModel.from_pretrained("facebook/encodec_48khz")
|
11 |
processor = AutoProcessor.from_pretrained("facebook/encodec_48khz")
|
12 |
|
13 |
@spaces.GPU
|
14 |
+
def encode(audio_file_path):
|
15 |
try:
|
16 |
# Open the audio file
|
17 |
with open(audio_file_path, "rb") as audio_file:
|
18 |
# Load and preprocess the audio
|
19 |
audio_sample, sampling_rate = load_dataset("audiofolder", data_dir=audio_file_path, split="train")[0]["audio"]
|
20 |
+
wav = torch.tensor(audio_sample).unsqueeze(0)
|
21 |
|
22 |
+
# Compress to ecdc
|
23 |
+
compressed_audio = compress(model, wav)
|
|
|
24 |
|
25 |
+
# Save compressed audio to BytesIO
|
26 |
+
output = io.BytesIO(compressed_audio)
|
27 |
+
output.seek(0)
|
28 |
+
|
29 |
+
return output
|
|
|
30 |
|
31 |
except Exception as e:
|
32 |
gr.Warning(f"An error occurred during encoding: {e}")
|
33 |
return None
|
34 |
|
35 |
@spaces.GPU
|
36 |
+
def decode(compressed_audio_file):
|
37 |
try:
|
38 |
+
# Load compressed audio
|
39 |
+
compressed_audio = compressed_audio_file.read()
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
+
# Decompress audio
|
42 |
+
wav, sr = decompress(compressed_audio)
|
43 |
+
|
44 |
# Convert the decoded audio to a numpy array for Gradio output
|
45 |
+
decoded_audio = wav.cpu().numpy()
|
46 |
|
47 |
return decoded_audio
|
48 |
|
|
|
59 |
audio_input = gr.Audio(type="filepath", label="Input Audio")
|
60 |
encode_button = gr.Button("Encode", variant="primary")
|
61 |
with gr.Row():
|
62 |
+
encoded_output = gr.File(label="Compressed Audio (.ecdc)")
|
63 |
|
64 |
encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
|
65 |
|
66 |
with gr.Tab("Decode"):
|
67 |
with gr.Row():
|
68 |
+
compressed_input = gr.File(label="Compressed Audio (.ecdc)")
|
69 |
decode_button = gr.Button("Decode", variant="primary")
|
70 |
with gr.Row():
|
71 |
decoded_output = gr.Audio(label="Decompressed Audio")
|
72 |
|
73 |
+
decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)
|
74 |
|
75 |
demo.queue().launch()
|