File size: 2,755 Bytes
f18f98b
 
c27eb74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44bab11
f18f98b
c27eb74
 
 
 
 
44bab11
c6e9cf1
44bab11
f18f98b
c27eb74
 
 
 
 
44bab11
c27eb74
f18f98b
c27eb74
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import gradio as gr
import torch
from datasets import load_dataset, Audio
from transformers import EncodecModel, AutoProcessor
import spaces 

# Single Encodec checkpoint id shared by the model weights and its
# preprocessing object, so the two can never drift apart.
_CHECKPOINT = "facebook/encodec_48khz"

# Loaded once at import time and reused by every request.
model, processor = (
    EncodecModel.from_pretrained(_CHECKPOINT),
    AutoProcessor.from_pretrained(_CHECKPOINT),
)

@spaces.GPU
def encode(audio_file):
    try:
        # Load and preprocess the audio
        audio_sample, sampling_rate = load_dataset("audiofolder", data_dir=audio_file.name, split="train")[0]["audio"]
        inputs = processor(raw_audio=audio_sample, sampling_rate=sampling_rate, return_tensors="pt")

        # Encode the audio
        with torch.no_grad():
            encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])

        # Extract the encoded codes and scales
        audio_codes = encoder_outputs.audio_codes
        audio_scales = encoder_outputs.audio_scales

        # Return the encoded data 
        return {"codes": audio_codes.tolist(), "scales": audio_scales.tolist()}

    except Exception as e:
        gr.Warning(f"An error occurred during encoding: {e}")
        return None

@spaces.GPU
def decode(encoded_data):
    try:
        # Convert the encoded data back to tensors
        audio_codes = torch.tensor(encoded_data["codes"])
        audio_scales = torch.tensor(encoded_data["scales"])

        # Decode the audio
        with torch.no_grad():
            audio_values = model.decode(audio_codes, audio_scales)[0]

        # Convert the decoded audio to a numpy array for Gradio output
        decoded_audio = audio_values.cpu().numpy() 

        return decoded_audio

    except Exception as e:
        gr.Warning(f"An error occurred during decoding: {e}")
        return None

# Gradio interface with improved design
with gr.Blocks(css=".gradio-container {background-color: #f0f0f0; padding: 20px;} .gr-button {background-color: #4CAF50; color: white;}") as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with Encodec</h1>")

    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.JSON(label="Encoded Audio")

        encode_button.click(encode, inputs=audio_input, outputs=encoded_output)

    with gr.Tab("Decode"):
        with gr.Row():
            encoded_input = gr.JSON(label="Encoded Audio")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")

        decode_button.click(decode, inputs=encoded_input, outputs=decoded_output)

demo.queue().launch()