|
import gradio as gr |
|
import jax |
|
import jax.numpy as jnp |
|
import librosa |
|
import dac_jax |
|
from dac_jax.audio_utils import volume_norm, db2linear |
|
import soundfile as sf |
|
import spaces |
|
import tempfile |
|
import os |
|
import numpy as np |
|
|
|
|
|
# Best-effort TPU initialisation for Colab-style runtimes. Any failure (for
# example the helper module not being available in the installed JAX, or no
# TPU being attached) falls back to whatever backend JAX selects by default.
try:
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
    print("Connected to TPU")
except Exception:
    # Narrowed from a bare ``except:`` so KeyboardInterrupt / SystemExit are
    # no longer silently swallowed during startup.
    print("No TPU detected, using GPU or CPU.")

# Load the 44 kHz DAC model once at import time; ``model`` and ``variables``
# are read from module scope by the jitted chunk functions and the handlers.
# NOTE(review): meaning of ``padding=False`` is defined by dac_jax — confirm.
model, variables = dac_jax.load_model(model_type="44khz", padding=False)
|
|
|
|
|
@spaces.GPU
@jax.jit
def compress_chunk(x):
    """Run the DAC model's ``compress_chunk`` method on one input chunk.

    Compiled with ``jax.jit``; ``model`` and ``variables`` are taken from
    module scope. Passed to ``model.compress`` as its per-chunk callback.
    """
    encoded = model.apply(variables, x, method="compress_chunk")
    return encoded
|
|
|
@spaces.GPU
@jax.jit
def decompress_chunk(c):
    """Run the DAC model's ``decompress_chunk`` method on one chunk of codes.

    Compiled with ``jax.jit``; ``model`` and ``variables`` are taken from
    module scope. Passed to ``model.decompress`` as its per-chunk callback.
    """
    decoded = model.apply(variables, c, method="decompress_chunk")
    return decoded
|
|
|
def ensure_mono(audio, sr):
    """Return ``(audio, sr)``, downmixing multi-dimensional audio to mono.

    Arrays with at most one dimension are returned unchanged (same object).
    Anything with more dimensions is transposed and passed to
    ``librosa.to_mono``. The sample rate is passed through untouched.

    NOTE(review): the transpose assumes a ``(samples, channels)`` layout
    (soundfile style); ``librosa.load(..., mono=False)`` yields
    ``(channels, samples)`` — confirm the expected layout before relying on
    the multi-channel path.
    """
    if audio.ndim <= 1:
        return audio, sr
    downmixed = librosa.to_mono(audio.T)
    return downmixed, sr
|
|
|
@spaces.GPU
def encode(audio_file_path):
    """Compress an uploaded audio file into a ``.dac`` archive.

    The audio is loaded with librosa at 44100 Hz, passed through
    ``ensure_mono``, reshaped with two leading singleton axes and compressed
    with ``model.compress`` using the jitted ``compress_chunk`` in 1-second
    windows.

    Args:
        audio_file_path: Path of the uploaded file as provided by
            ``gr.Audio(type="filepath")``; ``None`` when nothing was uploaded.

    Returns:
        Path of the saved ``.dac`` file, or ``None`` on failure (a Gradio
        warning is shown in that case).
    """
    if audio_file_path is None:
        gr.Warning("Please upload an audio file to encode.")
        return None

    try:
        signal, sample_rate = librosa.load(audio_file_path, sr=44100)
        signal, sample_rate = ensure_mono(signal, sample_rate)

        # (samples,) -> (1, 1, samples); presumably (batch, channels, samples)
        # as dac_jax expects — TODO confirm.
        signal = jnp.array(signal, dtype=jnp.float32)
        signal = jnp.expand_dims(signal, axis=(0, 1))

        win_duration = 1.0
        dac_file = model.compress(
            compress_chunk, signal, sample_rate, win_duration=win_duration
        )

        # Save into a fresh per-request directory: the previous fixed path in
        # the working directory was shared by all users, so concurrent
        # requests overwrote each other's output before it was downloaded.
        # The basename is kept so the downloaded file name is unchanged.
        # NOTE: these temp directories are not cleaned up here.
        output_dir = tempfile.mkdtemp()
        output_path = os.path.join(output_dir, "compressed_audio.dac")
        dac_file.save(output_path)
        return output_path
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
|
|
|
@spaces.GPU
def decode(compressed_dac_file):
    """Decompress an uploaded ``.dac`` archive back into a waveform.

    Args:
        compressed_dac_file: The upload from ``gr.File``. Gradio 4 passes a
            path string, while older Gradio versions pass a tempfile-like
            object exposing ``.name``; both are accepted. ``None`` when
            nothing was uploaded.

    Returns:
        ``(44100, samples)`` tuple suitable for ``gr.Audio``, or ``None`` on
        failure (a Gradio warning is shown in that case).
    """
    if compressed_dac_file is None:
        gr.Warning("Please upload a .dac file to decode.")
        return None

    # Previously this always read ``.name``, which raises AttributeError for
    # the plain ``str`` path that Gradio 4's gr.File delivers.
    dac_path = getattr(compressed_dac_file, "name", compressed_dac_file)

    try:
        dac_file = dac_jax.DACFile.load(dac_path)
        y = model.decompress(decompress_chunk, dac_file)

        # Remove all size-1 axes so single-channel output becomes 1-D.
        decoded_audio = np.array(y).squeeze()

        # NOTE(review): the rate is hard-coded to 44100, matching files made
        # by encode() (which loads at 44100) — confirm for external files.
        return (44100, decoded_audio)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")

    # Encode tab: audio file in, compressed .dac archive out.
    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.dac)")

        encode_button.click(
            fn=encode,
            inputs=audio_input,
            outputs=encoded_output,
        )

    # Decode tab: .dac archive in, playable audio out.
    with gr.Tab("Decode"):
        with gr.Row():
            compressed_input = gr.File(label="Compressed Audio (.dac)")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")

        decode_button.click(
            fn=decode,
            inputs=compressed_input,
            outputs=decoded_output,
        )

# Enable request queuing, then start the server.
demo.queue()
demo.launch()