File size: 2,755 Bytes
f18f98b c27eb74 44bab11 f18f98b c27eb74 44bab11 c6e9cf1 44bab11 f18f98b c27eb74 44bab11 c27eb74 f18f98b c27eb74 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import gradio as gr
import numpy as np
import spaces
import torch
from datasets import Audio, Dataset, load_dataset
from transformers import AutoProcessor, EncodecModel
# Single Encodec checkpoint used by both encode() and decode(); the processor is
# loaded from the same checkpoint so its feature-extraction settings match the model.
ENCODEC_CHECKPOINT = "facebook/encodec_48khz"
model = EncodecModel.from_pretrained(ENCODEC_CHECKPOINT)
processor = AutoProcessor.from_pretrained(ENCODEC_CHECKPOINT)
def _to_jsonable(value):
    """Recursively convert a tensor, ``None``, or a (nested) list of them to plain Python data."""
    if value is None:
        return None
    if isinstance(value, torch.Tensor):
        return value.tolist()
    return [_to_jsonable(item) for item in value]


def _match_channels(samples, channels):
    """Shape a waveform the way the Encodec feature extractor expects.

    The extractor takes ``(num_samples,)`` for a mono model and
    ``(channels, num_samples)`` for a multi-channel model. A single-channel
    recording given to a multi-channel model is duplicated across channels; a
    multi-channel recording given to a mono model is averaged down.

    Args:
        samples: Waveform array; assumed to be 1-D or channels-first 2-D
            (NOTE(review): matches what ``datasets`` yields — confirm).
        channels: Number of channels the extractor expects
            (``processor.feature_size``).

    Returns:
        A float32 numpy array in the layout described above.
    """
    samples = np.asarray(samples, dtype=np.float32)
    if channels == 1:
        return samples if samples.ndim == 1 else samples.mean(axis=0)
    if samples.ndim == 1:
        samples = samples[np.newaxis, :]
    if samples.shape[0] == 1:
        samples = np.repeat(samples, channels, axis=0)
    return samples


@spaces.GPU
def encode(audio_file):
    """Compress an audio file into Encodec codebook indices and per-chunk scales.

    Args:
        audio_file: Path to the uploaded file, as produced by
            ``gr.Audio(type="filepath")``; ``None`` when nothing was uploaded.

    Returns:
        A JSON-serialisable dict ``{"codes": [...], "scales": [...]}`` that
        ``decode`` can consume, or ``None`` on failure. Failures are reported to
        the user through ``gr.Warning`` rather than raised, so the UI stays usable.
    """
    if not audio_file:
        gr.Warning("Please provide an audio file to encode.")
        return None
    try:
        # Decode the file and resample it to the model's rate in one step; the
        # feature extractor refuses audio recorded at any other sampling rate.
        dataset = Dataset.from_dict({"audio": [audio_file]}).cast_column(
            "audio", Audio(sampling_rate=processor.sampling_rate)
        )
        audio = dataset[0]["audio"]
        samples = _match_channels(audio["array"], processor.feature_size)
        inputs = processor(
            raw_audio=samples,
            sampling_rate=audio["sampling_rate"],
            return_tensors="pt",
        )

        # Keep inputs on the same device as the model, wherever it was placed.
        with torch.no_grad():
            encoder_outputs = model.encode(
                inputs["input_values"].to(model.device),
                inputs["padding_mask"].to(model.device),
            )

        # Scales may be a per-chunk list of tensors rather than a single tensor,
        # so convert recursively instead of calling .tolist() directly.
        return {
            "codes": _to_jsonable(encoder_outputs.audio_codes),
            "scales": _to_jsonable(encoder_outputs.audio_scales),
        }
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
@spaces.GPU
def decode(encoded_data):
    """Reconstruct audio from the ``{"codes", "scales"}`` dict produced by ``encode``.

    Args:
        encoded_data: Parsed JSON from the ``gr.JSON`` input; expected to hold
            ``"codes"`` (nested list of ints) and ``"scales"`` (per-chunk list
            whose entries may be ``null``). ``None``/empty when the box is empty.

    Returns:
        ``(sampling_rate, samples)`` with ``samples`` laid out as
        ``(num_samples, num_channels)``, which is the format ``gr.Audio``
        outputs accept, or ``None`` on failure. Failures are reported through
        ``gr.Warning`` instead of being raised.
    """
    if not encoded_data:
        gr.Warning("Please provide encoded audio data to decode.")
        return None
    try:
        audio_codes = torch.tensor(
            encoded_data["codes"], dtype=torch.long, device=model.device
        )
        # One scale per chunk; keep null entries as None rather than feeding
        # them to torch.tensor (which cannot build a tensor from None).
        raw_scales = encoded_data.get("scales") or [None] * len(audio_codes)
        audio_scales = [
            None if scale is None else torch.tensor(scale, device=model.device)
            for scale in raw_scales
        ]

        with torch.no_grad():
            audio_values = model.decode(audio_codes, audio_scales)[0]

        # Take the first (only) batch item and transpose so channels come last,
        # as gr.Audio expects; the transpose is a no-op for a 1-D result.
        decoded_audio = audio_values[0].cpu().numpy().T
        return model.config.sampling_rate, decoded_audio
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
# Page styling: light grey padded container and green buttons.
APP_CSS = (
    ".gradio-container {background-color: #f0f0f0; padding: 20px;} "
    ".gr-button {background-color: #4CAF50; color: white;}"
)

with gr.Blocks(css=APP_CSS) as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with Encodec</h1>")

    # Tab 1: upload audio and turn it into JSON codes/scales.
    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.JSON(label="Encoded Audio")
        encode_button.click(encode, inputs=audio_input, outputs=encoded_output)

    # Tab 2: paste JSON codes/scales and turn them back into audio.
    with gr.Tab("Decode"):
        with gr.Row():
            encoded_input = gr.JSON(label="Encoded Audio")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(decode, inputs=encoded_input, outputs=decoded_output)

# Enable request queuing (needed for long-running GPU jobs), then serve the app.
demo.queue()
demo.launch()
|