|
import gradio as gr |
|
import torch |
|
from datasets import load_dataset, Audio |
|
from transformers import EncodecModel, AutoProcessor |
|
import spaces |
|
|
|
|
|
# The feature extractor and the codec weights come from the same Hub
# checkpoint. Both are loaded once at import time and shared by the
# encode and decode handlers below.
_ENCODEC_CHECKPOINT = "facebook/encodec_48khz"
processor = AutoProcessor.from_pretrained(_ENCODEC_CHECKPOINT)
model = EncodecModel.from_pretrained(_ENCODEC_CHECKPOINT)
|
|
|
@spaces.GPU
def encode(audio_file):
    """Compress an audio file into EnCodec codes and per-chunk scales.

    Args:
        audio_file: Path of the uploaded clip, as delivered by the
            ``gr.Audio(type="filepath")`` input. An object that exposes the
            path through a ``.name`` attribute (e.g. a temp-file wrapper) is
            also accepted for backward compatibility.

    Returns:
        A JSON-serialisable dict ``{"codes": ..., "scales": ...}``.
        ``codes`` is ``encoder_outputs.audio_codes`` as nested lists.
        ``scales`` has one entry per encoded chunk: nested lists, or ``None``
        if the model produced no scale for that chunk.
        Returns ``None`` after showing a Gradio warning when no file is
        supplied or any step fails.
    """
    if audio_file is None:
        gr.Warning("Please provide an audio file to encode.")
        return None

    try:
        # gr.Audio(type="filepath") hands us a plain str, which has no
        # `.name` attribute; fall back to the value itself in that case.
        path = getattr(audio_file, "name", audio_file)

        # Load exactly this one file, decoded at the rate the processor
        # requires. mono=False keeps every channel so that stereo input is
        # not downmixed before it reaches the model.
        dataset = load_dataset("audiofolder", data_files=[path], split="train")
        dataset = dataset.cast_column(
            "audio", Audio(sampling_rate=processor.sampling_rate, mono=False)
        )
        audio = dataset[0]["audio"]
        audio_array = audio["array"]
        sampling_rate = audio["sampling_rate"]

        # Adapt the channel count to what the checkpoint's config declares.
        expected_channels = model.config.audio_channels
        if expected_channels == 2 and audio_array.ndim == 1:
            # Mono file into a stereo model: duplicate into (2, samples).
            audio_array = audio_array[None, :].repeat(2, axis=0)
        elif expected_channels == 1 and audio_array.ndim == 2:
            # Multichannel file into a mono model: average the channels.
            audio_array = audio_array.mean(axis=0)

        inputs = processor(
            raw_audio=audio_array, sampling_rate=sampling_rate, return_tensors="pt"
        )
        # Keep the tensors on the same device as the model weights.
        input_values = inputs["input_values"].to(model.device)
        padding_mask = inputs.get("padding_mask")
        if padding_mask is not None:
            padding_mask = padding_mask.to(model.device)

        with torch.no_grad():
            encoder_outputs = model.encode(input_values, padding_mask)

        # audio_scales is a Python list (one item per chunk), not a tensor,
        # so it has to be converted element-wise.
        codes = encoder_outputs.audio_codes.cpu().tolist()
        scales = [
            None if scale is None else scale.cpu().tolist()
            for scale in encoder_outputs.audio_scales
        ]
        return {"codes": codes, "scales": scales}

    except Exception as e:
        # UI boundary: surface the problem as a toast instead of crashing.
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
|
|
|
@spaces.GPU
def decode(encoded_data):
    """Reconstruct playable audio from the dict produced by ``encode``.

    Args:
        encoded_data: Dict containing:
            - ``"codes"``: the ``audio_codes`` tensor as nested lists.
            - ``"scales"``: a list with one entry per chunk, each either
              nested lists or ``None``.

    Returns:
        A ``(sampling_rate, samples)`` tuple suitable for a ``gr.Audio``
        output. ``samples`` is a numpy array shaped ``(num_samples,
        channels)``. Returns ``None`` after showing a Gradio warning when the
        input is empty or decoding fails.
    """
    if not encoded_data:
        gr.Warning("Please provide encoded audio data to decode.")
        return None

    try:
        device = model.device
        # Codes are codebook indices, so they must be an integer tensor.
        audio_codes = torch.tensor(
            encoded_data["codes"], dtype=torch.long, device=device
        )
        # Scales are rebuilt per chunk because individual entries may be None.
        audio_scales = [
            None if scale is None else torch.tensor(scale, device=device)
            for scale in encoded_data["scales"]
        ]

        with torch.no_grad():
            audio_values = model.decode(audio_codes, audio_scales)[0]

        # audio_values is (batch, channels, samples). gr.Audio expects a
        # single clip as (samples, channels) paired with its sample rate.
        samples = audio_values[0].transpose(0, 1).cpu().numpy()
        return processor.sampling_rate, samples

    except Exception as e:
        # UI boundary: surface the problem as a toast instead of crashing.
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
|
|
|
|
|
_PAGE_CSS = (
    ".gradio-container {background-color: #f0f0f0; padding: 20px;} "
    ".gr-button {background-color: #4CAF50; color: white;}"
)

# Two-tab UI:
#   "Encode" turns an uploaded clip into JSON codes/scales via `encode`.
#   "Decode" turns that JSON back into audio via `decode`.
with gr.Blocks(css=_PAGE_CSS) as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with Encodec</h1>")

    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.JSON(label="Encoded Audio")

    with gr.Tab("Decode"):
        with gr.Row():
            encoded_input = gr.JSON(label="Encoded Audio")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")

    # Listeners are registered on the Blocks app itself, so both buttons
    # are wired here once the full layout exists.
    encode_button.click(fn=encode, inputs=[audio_input], outputs=[encoded_output])
    decode_button.click(fn=decode, inputs=[encoded_input], outputs=[decoded_output])

# Enable the request queue, then start the server.
demo.queue()
demo.launch()
|
|