"""Gradio demo: compress audio into DAC-JAX codes and reconstruct audio from them."""

import io
import tempfile

import gradio as gr
import jax.numpy as jnp
import librosa
import numpy as np
import soundfile as sf

import dac_jax
from dac_jax.audio_utils import volume_norm, db2linear

try:
    # Hugging Face Spaces GPU decorator; the package only exists on Spaces.
    import spaces
except ImportError:
    spaces = None


def _gpu(fn):
    """Wrap ``fn`` with ``spaces.GPU`` when available, else return it unchanged."""
    return spaces.GPU(fn) if spaces is not None else fn


# Rate audio is resampled to on load (matches the "44khz" model loaded below).
SAMPLE_RATE = 44100
# Loudness, in dB, that input audio is normalized to before encoding.
TARGET_DB = -16

# Load the DAC model once and bind its weights so it can be called directly.
model, variables = dac_jax.load_model(model_type="44khz")
model = model.bind(variables)


@_gpu
def encode(audio_file_path):
    """Encode an audio file into DAC codes stored in a ``.npz`` file.

    Args:
        audio_file_path: Path to the input audio (``gr.Audio(type="filepath")``),
            or ``None`` when nothing was provided.

    Returns:
        Path to a ``.npz`` file containing ``codes``, ``latents``,
        ``input_db``, ``target_db``, ``sample_rate`` and ``length``, or
        ``None`` on failure (a warning is shown in the UI instead).
    """
    if audio_file_path is None:
        gr.Warning("Please provide an audio file to encode.")
        return None
    try:
        # Load as mono, resampled to SAMPLE_RATE.
        signal, sample_rate = librosa.load(audio_file_path, sr=SAMPLE_RATE, mono=True)
        num_samples = signal.shape[-1]
        signal = jnp.asarray(signal, dtype=jnp.float32)
        # Add leading axes until 3-D; presumably (batch, channels, samples) —
        # confirm against the dac_jax model input convention.
        while signal.ndim < 3:
            signal = jnp.expand_dims(signal, axis=0)

        x, input_db = volume_norm(signal, TARGET_DB, sample_rate)
        x = model.preprocess(x, sample_rate)
        _z, codes, latents, _commitment_loss, _codebook_loss = model.encode(x, train=False)

        # gr.File outputs must be a path on disk, not an in-memory buffer.
        # numpy's .npz format avoids pickle, so decode can load it safely.
        with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as handle:
            np.savez(
                handle,
                codes=np.asarray(codes),
                latents=np.asarray(latents),
                input_db=np.asarray(input_db),
                target_db=np.asarray(TARGET_DB),
                sample_rate=np.asarray(sample_rate),
                length=np.asarray(num_samples),
            )
        return handle.name
    except Exception as e:
        gr.Warning(f"An error occurred during encoding: {e}")
        return None


@_gpu
def decode(encoded_data_file):
    """Reconstruct audio from a ``.npz`` file produced by :func:`encode`.

    Args:
        encoded_data_file: The uploaded file from ``gr.File`` — a path string,
            or an object exposing the path as ``.name`` — or ``None``.

    Returns:
        ``(sample_rate, samples)`` as expected by a numpy ``gr.Audio`` output,
        or ``None`` on failure (a warning is shown in the UI instead).
    """
    if encoded_data_file is None:
        gr.Warning("Please provide an encoded data file to decode.")
        return None
    try:
        path = getattr(encoded_data_file, "name", encoded_data_file)
        # The file is user-uploaded: never allow unpickling arbitrary objects.
        with np.load(path, allow_pickle=False) as data:
            codes = jnp.asarray(data["codes"])
            latents = jnp.asarray(data["latents"])
            input_db = jnp.asarray(data["input_db"])
            target_db = float(data["target_db"])
            sample_rate = int(data["sample_rate"])
            length = int(data["length"])

        # NOTE(review): confirm the dac_jax quantizer method for rebuilding z
        # from codes (`decode` here; the PyTorch DAC API names it `from_codes`).
        z = model.quantizer.decode(codes, latents)
        y = model.decode(z)

        # Undo the loudness normalization applied in encode.
        y = y * db2linear(input_db - target_db)

        # Drop singleton axes, then trim to the original sample count in case
        # the model output is longer (e.g. padding from preprocess — confirm).
        decoded_audio = np.asarray(y).squeeze()[..., :length]
        return sample_rate, decoded_audio
    except Exception as e:
        gr.Warning(f"An error occurred during decoding: {e}")
        return None


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Audio Compression with DAC-JAX")
    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Encoded Data")
        encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
    with gr.Tab("Decode"):
        with gr.Row():
            encoded_input = gr.File(label="Encoded Data", file_types=[".npz"])
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(decode, inputs=encoded_input, outputs=decoded_output)

if __name__ == "__main__":
    demo.queue().launch()