File size: 3,689 Bytes
f18f98b
d8d7a8d
60f3b28
 
 
 
 
d8d7a8d
a2cc897
 
bd40662
 
0586f5f
 
 
 
 
 
 
d8d7a8d
 
 
c27eb74
feebd9f
 
d8d7a8d
 
 
 
feebd9f
d8d7a8d
 
 
dd53483
feebd9f
 
 
 
 
60f3b28
5956a6f
c27eb74
feebd9f
 
 
60f3b28
 
feebd9f
dfdd7ad
d8d7a8d
f84f8af
dfdd7ad
d8d7a8d
 
c27eb74
9651813
 
 
c9a89ac
 
 
c27eb74
 
 
 
 
60f3b28
d8d7a8d
c27eb74
9651813
 
 
 
 
d8d7a8d
9651813
60f3b28
d8d7a8d
 
7e97379
60f3b28
bd40662
c27eb74
9651813
 
 
a2cc897
c27eb74
 
 
 
 
d07be48
 
60f3b28
44bab11
f18f98b
c27eb74
 
 
 
d8d7a8d
44bab11
c6e9cf1
44bab11
f18f98b
c27eb74
d8d7a8d
c27eb74
 
 
44bab11
d8d7a8d
f18f98b
c9a89ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import dac_jax
from dac_jax.audio_utils import volume_norm, db2linear
import soundfile as sf 
import spaces
import tempfile
import os
import numpy as np

# Try to attach to a Colab-style TPU. If that is not possible, JAX runs on
# whatever backend it selects by default (the message below assumes GPU/CPU).
try:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
    print("Connected to TPU")
except Exception:
    # `except Exception` instead of a bare `except:` so KeyboardInterrupt and
    # SystemExit still propagate. ImportError covers JAX builds without
    # jax.tools.colab_tpu; setup_tpu presumably raises when not running on a
    # Colab TPU runtime.
    print("No TPU detected, using GPU or CPU.")

# Load the DAC model with padding set to False for chunking
model, variables = dac_jax.load_model(model_type="44khz", padding=False)

# GPU-accelerated and jit-compiled chunk processing functions
@spaces.GPU
@jax.jit
def compress_chunk(x):
    """Run the DAC model's ``compress_chunk`` method on one audio chunk.

    ``model`` and ``variables`` are module globals captured by closure, so
    ``jax.jit`` treats them as constants of the compiled function. This is
    passed to ``model.compress`` as the per-chunk encoder.
    """
    chunk_codes = model.apply(variables, x, method="compress_chunk")
    return chunk_codes

@spaces.GPU
@jax.jit
def decompress_chunk(c):
    """Run the DAC model's ``decompress_chunk`` method on one chunk of codes.

    ``model`` and ``variables`` are module globals captured by closure, so
    ``jax.jit`` treats them as constants of the compiled function. This is
    passed to ``model.decompress`` as the per-chunk decoder.
    """
    chunk_audio = model.apply(variables, c, method="decompress_chunk")
    return chunk_audio

def ensure_mono(audio, sr):
    """Down-mix ``audio`` to one channel when it has more than one axis.

    Args:
        audio: NumPy audio array. A 1-D array is returned unchanged (the
            same object).
        sr: Sample rate, passed through untouched.

    Returns:
        ``(audio, sr)``: the mono signal and the unchanged sample rate.

    NOTE(review): the array is transposed before ``librosa.to_mono``, which
    averages over every axis except the last. That is right for
    ``(samples, channels)`` input (soundfile's layout), but ``librosa.load``
    returns multi-channel audio as ``(channels, samples)``. Confirm the
    intended layout before relying on this branch.
    """
    is_multichannel = audio.ndim > 1
    if not is_multichannel:
        return audio, sr
    mono = librosa.to_mono(audio.T)
    return mono, sr

@spaces.GPU
def encode(audio_file_path):
    """Compress an uploaded audio file into a temporary ``.dac`` file.

    The audio is loaded with librosa at 44.1 kHz. ``librosa.load`` is called
    without ``mono=``, so it uses librosa's default mono down-mix. The signal
    is then shaped to ``(1, 1, samples)`` and compressed by the DAC model in
    5-second windows.

    Args:
        audio_file_path: Path from the ``gr.Audio(type="filepath")`` input,
            or ``None`` when nothing was uploaded.

    Returns:
        Path of the written ``.dac`` file, or ``None`` when no input was given
        or encoding failed (a ``gr.Warning`` toast is shown in that case).
    """
    if audio_file_path is None:
        gr.Warning("Please upload an audio file to encode.")
        return None

    dac_path = None
    try:
        # Load and ensure mono audio
        signal, sample_rate = librosa.load(audio_file_path, sr=44100)
        signal, sample_rate = ensure_mono(signal, sample_rate)

        signal = jnp.array(signal, dtype=jnp.float32)
        signal = jnp.expand_dims(signal, axis=(0, 1))  # Add batch and channel dimensions

        # Set chunk duration based on available GPU memory (adjust as needed)
        win_duration = 5.0

        # Compress using chunking
        dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)

        # Reserve a unique path and close our handle before saving. The save
        # call receives only the path and must open it itself, and an open
        # NamedTemporaryFile cannot be reopened by name on Windows.
        fd, dac_path = tempfile.mkstemp(suffix=".dac")
        os.close(fd)
        dac_file.save(dac_path)

        # Return the temporary file path
        return dac_path

    except Exception as e:
        # Do not leave an empty or half-written temp file behind on failure.
        if dac_path is not None and os.path.exists(dac_path):
            os.unlink(dac_path)
        gr.Warning(f"An error occurred during encoding: {e}")
        return None

@spaces.GPU
def decode(compressed_dac_file):
    """Decompress an uploaded ``.dac`` file back into audio.

    Args:
        compressed_dac_file: Value of the ``gr.File`` input. Gradio 4 passes
            a filepath string; older Gradio versions passed a tempfile wrapper
            that exposes the path as ``.name``. ``None`` when nothing was
            uploaded.

    Returns:
        ``(44100, audio)``, where ``audio`` is the squeezed NumPy array
        produced by the decoder. Returns ``None`` when no input was given or
        decoding failed (a ``gr.Warning`` toast is shown in that case).
    """
    if compressed_dac_file is None:
        gr.Warning("Please upload a .dac file to decode.")
        return None

    try:
        # Gradio has already stored the upload on disk, so read it from its
        # path. A plain str (Gradio 4) has no `.file` attribute to read from,
        # and copying into another temp file is unnecessary.
        if isinstance(compressed_dac_file, (str, os.PathLike)):
            dac_path = os.fspath(compressed_dac_file)
        else:
            dac_path = compressed_dac_file.name

        # NOTE(review): the upload is untrusted input. Confirm DACFile.load
        # does not unpickle its contents (e.g. np.load(allow_pickle=True)),
        # which would let a crafted file execute arbitrary code.
        dac_file = dac_jax.DACFile.load(dac_path)

        # Decompress using chunking
        y = model.decompress(decompress_chunk, dac_file)

        # Convert to numpy array and squeeze to remove extra dimensions
        decoded_audio = np.array(y).squeeze()

        return (44100, decoded_audio)  # Return sample rate and audio data

    except Exception as e:
        gr.Warning(f"An error occurred during decoding: {e}")
        return None

# Gradio interface. The "Encode" tab turns an audio file into a .dac file;
# the "Decode" tab turns a .dac file back into playable audio.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")

    # Encode tab: audio file in, compressed .dac file out.
    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(label="Input Audio", type="filepath")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.dac)")
        encode_button.click(fn=encode, inputs=audio_input, outputs=encoded_output)

    # Decode tab: .dac file in, audio out.
    with gr.Tab("Decode"):
        with gr.Row():
            compressed_input = gr.File(label="Compressed Audio (.dac)")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(fn=decode, inputs=compressed_input, outputs=decoded_output)

# Enable the request queue (Blocks.queue returns the same Blocks), then serve.
demo.queue()
demo.launch()