"""Gradio demo: compress audio to a .dac file and back with DAC-JAX."""

import io
import logging
import tempfile
from pathlib import Path

import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import numpy as np
import soundfile as sf
import spaces

import dac_jax
from dac_jax.audio_utils import db2linear, volume_norm

logger = logging.getLogger(__name__)

# Rate that input audio is resampled to before encoding (used with the
# "44khz" model loaded below).
SAMPLE_RATE = 44100

# Chunk duration in seconds for chunked compression. Tune to the available
# GPU memory (you might need to experiment with this value).
WIN_DURATION = 0.5

# Load the DAC model with padding set to False for chunking.
model, variables = dac_jax.load_model(model_type="44khz", padding=False)


# Jit-compile the chunk processing functions for efficiency; they are handed
# to model.compress / model.decompress below.
@jax.jit
def compress_chunk(x):
    return model.apply(variables, x, method="compress_chunk")


@jax.jit
def decompress_chunk(c):
    return model.apply(variables, c, method="decompress_chunk")


@spaces.GPU
def encode(audio_file_path):
    """Compress an uploaded audio file into a ``.dac`` file.

    Args:
        audio_file_path: Path of the uploaded audio (from
            ``gr.Audio(type="filepath")``), or ``None`` if nothing was uploaded.

    Returns:
        Path (str) of the written ``.dac`` file for the ``gr.File`` output, or
        ``None`` if there was no input or encoding failed. In both of those
        cases a warning toast is shown.
    """
    if audio_file_path is None:
        gr.Warning("Please upload an audio file to encode.")
        return None
    try:
        # Load as mono, resampled to SAMPLE_RATE.
        signal, sample_rate = librosa.load(
            audio_file_path, sr=SAMPLE_RATE, mono=True
        )
        signal = jnp.asarray(signal, dtype=jnp.float32)
        # Prepend axes until the array is 3-D: (1, 1, samples) for mono input.
        while signal.ndim < 3:
            signal = jnp.expand_dims(signal, axis=0)

        dac_file = model.compress(
            compress_chunk, signal, sample_rate, win_duration=WIN_DURATION
        )

        # DACFile.save is given a filesystem path (not a file object), and
        # gr.File needs a path to serve, so write into a fresh temp directory
        # under a name derived from the input file. Gradio copies returned
        # files into its own cache.
        out_dir = Path(tempfile.mkdtemp(prefix="dac_"))
        out_path = out_dir / f"{Path(audio_file_path).stem}.dac"
        dac_file.save(str(out_path))
        return str(out_path)
    except Exception as e:
        # UI boundary: keep the traceback in the logs, show a toast to the user.
        logger.exception("Encoding failed")
        gr.Warning(f"An error occurred during encoding: {e}")
        return None


@spaces.GPU
def decode(compressed_dac_file):
    """Decompress an uploaded ``.dac`` file back into audio.

    Args:
        compressed_dac_file: Value of the ``gr.File`` input. Depending on the
            Gradio version this is a path string or a tempfile-like object
            exposing ``.name``; ``None`` if nothing was uploaded.

    Returns:
        ``(sample_rate, samples)`` tuple for the ``gr.Audio`` output, or
        ``None`` if there was no input or decoding failed. In both of those
        cases a warning toast is shown.
    """
    if compressed_dac_file is None:
        gr.Warning("Please upload a .dac file to decode.")
        return None
    try:
        # Accept both a plain path and a tempfile wrapper.
        dac_path = getattr(compressed_dac_file, "name", compressed_dac_file)
        dac_file = dac_jax.DACFile.load(dac_path)

        # Decompress using chunking.
        y = model.decompress(decompress_chunk, dac_file)

        # Move to host memory and drop singleton batch/channel axes.
        audio = np.asarray(y).squeeze()
        if audio.ndim == 2:
            # NOTE(review): assumes (channels, samples) after squeeze;
            # gr.Audio expects (samples, channels).
            audio = audio.T

        # gr.Audio requires the rate alongside a numpy array. Presumably the
        # DAC file records the rate it was encoded at; fall back to the rate
        # encode() uses if the attribute is absent.
        sample_rate = int(getattr(dac_file, "sample_rate", SAMPLE_RATE))
        return sample_rate, audio
    except Exception as e:
        # UI boundary: keep the traceback in the logs, show a toast to the user.
        logger.exception("Decoding failed")
        gr.Warning(f"An error occurred during decoding: {e}")
        return None


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Audio Compression with DAC-JAX")

    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.dac)")
        encode_button.click(encode, inputs=audio_input, outputs=encoded_output)

    with gr.Tab("Decode"):
        with gr.Row():
            compressed_input = gr.File(label="Compressed Audio (.dac)")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(
            decode, inputs=compressed_input, outputs=decoded_output
        )


if __name__ == "__main__":
    demo.queue().launch()