File size: 2,846 Bytes
f18f98b
60f3b28
 
 
 
ab12e78
60f3b28
c27eb74
60f3b28
 
 
dd53483
60f3b28
ab12e78
c27eb74
60f3b28
 
 
 
 
 
dfdd7ad
60f3b28
 
dfdd7ad
60f3b28
 
 
c27eb74
60f3b28
dfdd7ad
60f3b28
7e97379
c27eb74
60f3b28
c27eb74
 
 
 
 
60f3b28
 
c27eb74
60f3b28
 
 
 
 
 
 
 
 
 
 
 
 
7e97379
60f3b28
 
c27eb74
 
 
 
 
 
 
d07be48
 
60f3b28
44bab11
f18f98b
c27eb74
 
 
 
60f3b28
44bab11
c6e9cf1
44bab11
f18f98b
c27eb74
60f3b28
c27eb74
 
 
44bab11
60f3b28
f18f98b
c27eb74
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import io
import os
import tempfile

import dac_jax
import gradio as gr
import jax.numpy as jnp
import librosa
import numpy as np
import soundfile as sf
import spaces
from dac_jax.audio_utils import volume_norm, db2linear

# Load the pretrained DAC model once at import time so every request reuses it.
# `load_model` returns the module together with its variables; binding the
# variables lets the handlers below call `model.encode` / `model.decode`
# directly on `model` without passing the variables each time.
model, variables = dac_jax.load_model(model_type="44khz")
model = model.bind(variables)

# Sample rate used for loading audio before encoding and for labelling the
# decoded audio on playback. Must match the "44khz" model loaded at module level.
_SAMPLE_RATE = 44100


def _as_path(file_value):
    """Return a filesystem path for a Gradio file value (path or file wrapper)."""
    if isinstance(file_value, (str, os.PathLike)):
        return file_value
    # Older Gradio versions hand over a temp-file wrapper exposing `.name`.
    return file_value.name


@spaces.GPU
def encode(audio_file_path):
    """Compress an audio file with the DAC model and return the path of the result.

    The audio is loaded as mono at ``_SAMPLE_RATE``, loudness-normalized to
    -16 dB, and encoded. The codes, latents and the two loudness values needed
    to undo the normalization are written to a compressed ``.npz`` archive.

    Args:
        audio_file_path: Path of the audio file to compress (as provided by a
            ``gr.Audio(type="filepath")`` input). ``None`` when nothing was
            uploaded.

    Returns:
        Path (``str``) of the ``.npz`` archive, suitable for a ``gr.File``
        output, or ``None`` if no input was given or encoding failed (a Gradio
        warning is shown in that case).
    """
    if audio_file_path is None:
        gr.Warning("Please provide an audio file to encode.")
        return None

    try:
        # Load a mono audio file
        signal, sample_rate = librosa.load(audio_file_path, sr=_SAMPLE_RATE, mono=True)

        # Prepend axes until the signal is rank 3
        # (presumably batch, channel, samples — confirm against dac_jax).
        signal = jnp.array(signal, dtype=jnp.float32)
        while signal.ndim < 3:
            signal = jnp.expand_dims(signal, axis=0)

        target_db = -16  # Normalize audio to -16 dB
        x, input_db = volume_norm(signal, target_db, sample_rate)

        # Encode audio signal; only codes and latents are persisted.
        x = model.preprocess(x, sample_rate)
        _z, codes, latents, _commitment_loss, _codebook_loss = model.encode(x, train=False)

        # Gradio's File output needs a path on disk, not an in-memory buffer.
        # Plain NumPy arrays in an .npz archive can be read back without
        # unpickling, which matters because `decode` consumes user uploads.
        # delete=False: the file must outlive this function so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as tmp:
            np.savez_compressed(
                tmp,
                codes=np.asarray(codes),
                latents=np.asarray(latents),
                input_db=np.asarray(input_db),
                target_db=np.asarray(target_db),
            )
        return tmp.name

    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing the app.
        gr.Warning(f"An error occurred during encoding: {e}")
        return None


@spaces.GPU
def decode(encoded_data_file):
    """Reconstruct audio from an archive produced by ``encode``.

    Args:
        encoded_data_file: Path (or Gradio file wrapper) of the ``.npz`` archive
            holding ``codes``, ``latents``, ``input_db`` and ``target_db``.
            ``None`` when nothing was uploaded.

    Returns:
        ``(sample_rate, samples)`` tuple for a ``gr.Audio`` output, where
        ``samples`` is a NumPy array with singleton dimensions removed, or
        ``None`` if no input was given or decoding failed (a Gradio warning is
        shown in that case).
    """
    if encoded_data_file is None:
        gr.Warning("Please provide an encoded file to decode.")
        return None

    try:
        # Load the encoded data. allow_pickle=False: the file is user-supplied,
        # so never execute pickled objects from it.
        with np.load(_as_path(encoded_data_file), allow_pickle=False) as archive:
            codes = jnp.asarray(archive["codes"])
            latents = jnp.asarray(archive["latents"])
            input_db = jnp.asarray(archive["input_db"])
            target_db = archive["target_db"].item()

        # Decode audio signal
        # NOTE(review): confirm that the installed dac_jax quantizer exposes
        # `decode(codes, latents)`; this call is kept as originally written.
        z = model.quantizer.decode(codes, latents)
        y = model.decode(z)

        # Undo previous loudness normalization
        y = y * db2linear(input_db - target_db)

        # Convert to numpy array and squeeze to remove extra dimensions
        decoded_audio = np.array(y).squeeze()

        # gr.Audio needs the sample rate alongside the samples.
        return _SAMPLE_RATE, decoded_audio

    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing the app.
        gr.Warning(f"An error occurred during decoding: {e}")
        return None

# Gradio UI: one tab per direction, each with an input, a trigger button and
# an output component wired to the matching handler.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")

    # Audio in -> encoded data file out.
    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(label="Input Audio", type="filepath")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Encoded Data")

        encode_button.click(
            fn=encode,
            inputs=audio_input,
            outputs=encoded_output,
        )

    # Encoded data file in -> reconstructed audio out.
    with gr.Tab("Decode"):
        with gr.Row():
            encoded_input = gr.File(label="Encoded Data")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")

        decode_button.click(
            fn=decode,
            inputs=encoded_input,
            outputs=decoded_output,
        )

# Enable request queuing, then start the server.
demo.queue()
demo.launch()