File size: 3,013 Bytes
f18f98b
d8d7a8d
60f3b28
 
 
a7dbbfe
d8d7a8d
a2cc897
 
bd40662
 
a7dbbfe
 
 
dd53483
60f3b28
5956a6f
c27eb74
5f18cf7
71e9a07
60f3b28
 
5f18cf7
 
dfdd7ad
a7dbbfe
 
c27eb74
a7dbbfe
 
 
 
 
 
 
c9a89ac
 
c27eb74
 
 
 
 
60f3b28
6b093ee
c27eb74
6b093ee
 
a7dbbfe
 
 
 
c27eb74
a7dbbfe
 
 
 
 
 
 
5f18cf7
c27eb74
 
 
 
 
d07be48
 
60f3b28
44bab11
f18f98b
c27eb74
 
 
 
a7dbbfe
44bab11
c6e9cf1
44bab11
f18f98b
c27eb74
a7dbbfe
c27eb74
 
 
44bab11
d8d7a8d
f18f98b
c9a89ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import dac_jax
from dac_jax.audio_utils import volume_norm, db2linear
import spaces
import tempfile
import os
import numpy as np

# Fetch the pretrained "44khz" DAC model together with its variables, then
# expose a bound module as `model` so encode/decode can call its methods
# directly without passing the variables around.
_unbound_model, variables = dac_jax.load_model(model_type="44khz")
model = _unbound_model.bind(variables)

@spaces.GPU
def encode(audio_file_path):
    """Compress an audio file with the DAC model and save the result as ``.npz``.

    The audio is loaded as mono at 44100 Hz, loudness-normalized to -16 dB,
    run through ``model.encode``, and the resulting ``z``, ``codes``,
    ``latents`` and the measured ``input_db`` are written to a temporary
    ``.npz`` archive.

    Args:
        audio_file_path: Path of the audio file to compress (as supplied by
            the ``gr.Audio(type="filepath")`` input); ``None`` when the user
            has not provided any audio.

    Returns:
        The path of the temporary ``.npz`` file, or ``None`` if no input was
        given or encoding failed (a Gradio warning is shown in that case).
    """
    if audio_file_path is None:
        gr.Warning("An error occurred during encoding: no audio file was provided")
        return None

    try:
        # Load the whole file, resampled to 44.1 kHz and downmixed to mono.
        signal, sample_rate = librosa.load(audio_file_path, sr=44100, mono=True)

        # Add leading singleton axes until the signal is 3-D, as the model expects.
        signal = jnp.array(signal, dtype=jnp.float32)
        while signal.ndim < 3:
            signal = jnp.expand_dims(signal, axis=0)

        target_db = -16  # Normalize audio to -16 dB
        x, input_db = volume_norm(signal, target_db, sample_rate)

        # Encode audio signal; the two loss terms are not needed for inference.
        x = model.preprocess(x, sample_rate)
        z, codes, latents, _commitment_loss, _codebook_loss = model.encode(x, train=False)

        # Bring everything to host memory before creating the output file so a
        # conversion failure cannot leave an empty temp file behind.
        arrays = {
            "z": np.asarray(z),
            "codes": np.asarray(codes),
            "latents": np.asarray(latents),
            "input_db": np.asarray(input_db),
        }

        # Write through the already-open handle: reopening the file by name
        # while NamedTemporaryFile still holds it fails on Windows.
        # delete=False keeps the file alive so Gradio can serve it afterwards.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".npz") as temp_file:
            np.savez(temp_file, **arrays)

        return temp_file.name

    except Exception as e:
        # Top-level UI boundary: report the failure to the user instead of crashing.
        gr.Warning(f"An error occurred during encoding: {e}")
        return None

@spaces.GPU
def decode(compressed_file_path):
    """Reconstruct audio from a ``.npz`` archive written by ``encode``.

    Only the ``z`` and ``input_db`` entries of the archive are needed: ``z``
    is passed to ``model.decode`` and ``input_db`` is used to undo the -16 dB
    loudness normalization applied during encoding.

    Args:
        compressed_file_path: Path of the uploaded ``.npz`` file; ``None``
            when the user has not provided a file.

    Returns:
        A ``(44100, samples)`` tuple for the ``gr.Audio`` output, where
        ``samples`` is the squeezed NumPy array of decoded audio, or ``None``
        if no input was given or decoding failed (a Gradio warning is shown).
    """
    if compressed_file_path is None:
        gr.Warning("An error occurred during decoding: no compressed file was provided")
        return None

    try:
        # The archive is user-supplied, so refuse pickled objects explicitly,
        # and use a context manager so the underlying file handle is closed.
        with np.load(compressed_file_path, allow_pickle=False) as data:
            z = data['z']
            input_db = data['input_db']

        # Decode audio signal.
        # NOTE(review): the requested length assumes axis 1 of `z` is the frame
        # axis; confirm against the dac_jax layout.
        y = model.decode(z, length=z.shape[1] * model.hop_length)

        # Undo previous loudness normalization; must match the target used in encode.
        target_db = -16
        y = y * db2linear(input_db - target_db)

        decoded_audio = np.array(y).squeeze()
        return (44100, decoded_audio)

    except Exception as e:
        # Top-level UI boundary: report the failure to the user instead of crashing.
        gr.Warning(f"An error occurred during decoding: {e}")
        return None

# Gradio interface: one tab for each direction of the codec round trip.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")

    # Encode tab: audio file in, downloadable .npz archive out.
    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(label="Input Audio", type="filepath")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.npz)")

        encode_button.click(
            fn=encode,
            inputs=audio_input,
            outputs=encoded_output,
        )

    # Decode tab: .npz archive in, playable audio out.
    with gr.Tab("Decode"):
        with gr.Row():
            compressed_input = gr.File(label="Compressed Audio (.npz)")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")

        decode_button.click(
            fn=decode,
            inputs=compressed_input,
            outputs=decoded_output,
        )

# Enable request queuing, then start the server.
demo.queue()
demo.launch()