"""Gradio demo: compress audio to a ``.dac`` file with DAC-JAX and decode it back."""

import os
import tempfile

import dac_jax
import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import numpy as np
import soundfile as sf
import spaces
from dac_jax.audio_utils import volume_norm, db2linear

# Sample rate used both when loading input audio and when returning decoded audio.
SAMPLE_RATE = 44100

# Duration (seconds) of each chunk handed to the jitted compress function.
# Adjust based on available accelerator memory.
WIN_DURATION = 1.0

# File name of the compressed artifact offered for download.
OUTPUT_FILENAME = "compressed_audio.dac"

# Try to attach to a Colab TPU; otherwise JAX falls back to GPU or CPU.
# `except Exception` (not a bare `except`) so that KeyboardInterrupt and
# SystemExit still propagate during startup.
try:
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
    print("Connected to TPU")
except Exception:
    print("No TPU detected, using GPU or CPU.")

# Load the DAC model with padding set to False for chunking
model, variables = dac_jax.load_model(model_type="44khz", padding=False)


# GPU-accelerated and jit-compiled chunk processing functions
@spaces.GPU
@jax.jit
def compress_chunk(x):
    """Apply the model's ``compress_chunk`` method to one chunk ``x``."""
    return model.apply(variables, x, method='compress_chunk')


@spaces.GPU
@jax.jit
def decompress_chunk(c):
    """Apply the model's ``decompress_chunk`` method to one chunk of codes ``c``."""
    return model.apply(variables, c, method='decompress_chunk')


def ensure_mono(audio, sr):
    """Return ``(audio, sr)`` with ``audio`` down-mixed to mono if it has >1 dim.

    The array is transposed before being handed to ``librosa.to_mono``.
    NOTE(review): that implies a (samples, channels) input layout -- confirm
    against callers. One-dimensional input is returned unchanged.
    """
    if audio.ndim > 1:
        return librosa.to_mono(audio.T), sr
    return audio, sr


@spaces.GPU
def encode(audio_file_path):
    """Compress the audio file at ``audio_file_path`` into a ``.dac`` file.

    Args:
        audio_file_path: Path to the input audio, or ``None`` when nothing
            was provided in the UI.

    Returns:
        Path to the written ``.dac`` file, or ``None`` if there was no input
        or encoding failed (a Gradio warning is shown in both cases).
    """
    if audio_file_path is None:
        gr.Warning("Please provide an audio file to encode.")
        return None

    try:
        # Load and ensure mono audio
        signal, sample_rate = librosa.load(audio_file_path, sr=SAMPLE_RATE)
        signal, sample_rate = ensure_mono(signal, sample_rate)

        signal = jnp.array(signal, dtype=jnp.float32)
        # Add batch and channel dimensions
        signal = jnp.expand_dims(signal, axis=(0, 1))

        # Compress using chunking
        dac_file = model.compress(
            compress_chunk, signal, sample_rate, win_duration=WIN_DURATION
        )

        # Write into a fresh per-request directory: a fixed path in the
        # working directory would be overwritten by concurrent requests.
        # The base name is kept so the downloaded file name is unchanged.
        output_path = os.path.join(tempfile.mkdtemp(), OUTPUT_FILENAME)
        dac_file.save(output_path)
        return output_path

    except Exception as e:
        gr.Warning(f"An error occurred during encoding: {e}")
        return None


@spaces.GPU
def decode(compressed_dac_file):
    """Decompress an uploaded ``.dac`` file back to audio.

    Args:
        compressed_dac_file: The uploaded file as delivered by ``gr.File``:
            either an object exposing the path as ``.name`` or a plain path
            string. ``None`` when nothing was uploaded.

    Returns:
        ``(sample_rate, samples)`` for ``gr.Audio``, or ``None`` if there was
        no input or decoding failed (a Gradio warning is shown in both cases).
    """
    if compressed_dac_file is None:
        gr.Warning("Please provide a .dac file to decode.")
        return None

    try:
        # Accept both file-wrapper objects (path in `.name`) and plain paths.
        dac_path = getattr(compressed_dac_file, "name", compressed_dac_file)

        # Load the compressed DAC file
        dac_file = dac_jax.DACFile.load(dac_path)

        # Decompress using chunking
        y = model.decompress(decompress_chunk, dac_file)

        # Convert to numpy array and squeeze to remove extra dimensions
        decoded_audio = np.array(y).squeeze()
        return (SAMPLE_RATE, decoded_audio)  # Return sample rate and audio data

    except Exception as e:
        gr.Warning(f"An error occurred during decoding: {e}")
        return None


# Gradio interface
with gr.Blocks() as demo:
    # Same text as before, written with escapes: a single-quoted string
    # literal cannot span physical lines.
    gr.Markdown("\n\nAudio Compression with DAC-JAX\n\n")

    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.dac)")

        encode_button.click(encode, inputs=audio_input, outputs=encoded_output)

    with gr.Tab("Decode"):
        with gr.Row():
            compressed_input = gr.File(label="Compressed Audio (.dac)")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")

        decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)

demo.queue().launch()