dac / app.py
owiedotch's picture
Update app.py
ab12e78 verified
raw
history blame
No virus
2.45 kB
import functools
import io
import tempfile

import gradio as gr
import torch
from datasets import load_dataset, Audio
from transformers import EncodecModel, AutoProcessor
import spaces
from encodec import compress, decompress
# Hugging Face Hub checkpoint shared by the model and its processor; both are
# downloaded and instantiated once, at import time.
# NOTE(review): these are the `transformers` Encodec classes, while the
# `compress`/`decompress` helpers come from the separate `encodec` package,
# which appears to expect its own EncodecModel — confirm these objects are
# still needed. `processor` is not referenced elsewhere in the visible code.
_ENCODEC_CHECKPOINT = "facebook/encodec_48khz"
model = EncodecModel.from_pretrained(_ENCODEC_CHECKPOINT)
processor = AutoProcessor.from_pretrained(_ENCODEC_CHECKPOINT)
@functools.lru_cache(maxsize=1)
def _load_ecdc_model():
    """Return the 48 kHz EnCodec model from the `encodec` package, loaded once.

    `encodec.compress` validates the model it is given against the `encodec`
    package's own pretrained models, so the `transformers` EncodecModel cannot
    be passed to it directly.
    """
    from encodec import EncodecModel as EcdcModel

    return EcdcModel.encodec_model_48khz()


@spaces.GPU
def encode(audio_file_path):
    """Compress an audio file into EnCodec's `.ecdc` bitstream format.

    Args:
        audio_file_path: Path to the input audio file, as supplied by
            `gr.Audio(type="filepath")`. May be None when nothing was uploaded.

    Returns:
        Path to a temporary `.ecdc` file containing the compressed audio
        (suitable for a `gr.File` output), or None on failure, in which case a
        Gradio warning is shown instead.
    """
    if not audio_file_path:
        gr.Warning("Please provide an audio file to encode.")
        return None
    try:
        # Imported lazily: both ship with the `encodec` package's dependencies.
        import torchaudio
        from encodec.utils import convert_audio

        ecdc_model = _load_ecdc_model()
        wav, sample_rate = torchaudio.load(audio_file_path)  # (channels, samples)
        # EnCodec expects audio at the model's own sample rate and channel
        # count, so resample / up- or down-mix before compressing.
        wav = convert_audio(
            wav, sample_rate, ecdc_model.sample_rate, ecdc_model.channels
        )
        compressed_audio = compress(ecdc_model, wav)
        # gr.File outputs serve files from disk, not in-memory buffers, so
        # persist the bitstream and hand back its path.
        with tempfile.NamedTemporaryFile(suffix=".ecdc", delete=False) as out_file:
            out_file.write(compressed_audio)
        return out_file.name
    except Exception as e:
        # Top-level UI boundary: surface the error to the user without crashing.
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
@spaces.GPU
def decode(compressed_audio_file):
    """Decompress an EnCodec `.ecdc` file back into a playable waveform.

    Args:
        compressed_audio_file: The uploaded file from a `gr.File` input.
            Depending on the Gradio version this is either a path string or a
            tempfile-like object exposing `.name`; both are accepted. May be
            None when nothing was uploaded.

    Returns:
        A `(sample_rate, samples)` tuple, with `samples` a numpy array shaped
        `(num_samples, channels)` as `gr.Audio` outputs expect, or None on
        failure, in which case a Gradio warning is shown instead.
    """
    if compressed_audio_file is None:
        gr.Warning("Please provide an .ecdc file to decode.")
        return None
    try:
        # Gradio passes a path (newer versions) or a tempfile wrapper (older
        # versions); neither is an open stream, so read the bytes from disk.
        path = getattr(compressed_audio_file, "name", compressed_audio_file)
        with open(path, "rb") as ecdc_file:
            compressed_audio = ecdc_file.read()
        wav, sr = decompress(compressed_audio)
        # decompress yields (channels, samples); gr.Audio wants
        # (samples, channels) together with the sample rate.
        decoded_audio = wav.detach().cpu().numpy().T
        return sr, decoded_audio
    except Exception as e:
        # Top-level UI boundary: surface the error to the user without crashing.
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
# Gradio UI: an "Encode" tab that turns uploaded audio into an .ecdc file, and
# a "Decode" tab that turns an .ecdc file back into playable audio.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with Encodec</h1>")

    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.ecdc)")
        encode_button.click(fn=encode, inputs=audio_input, outputs=encoded_output)

    with gr.Tab("Decode"):
        with gr.Row():
            compressed_input = gr.File(label="Compressed Audio (.ecdc)")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(fn=decode, inputs=compressed_input, outputs=decoded_output)

# Enable the request queue (Blocks.queue returns the same Blocks object), then
# start the server.
demo.queue()
demo.launch()