import gradio as gr import jax import jax.numpy as jnp import librosa import dac_jax import spaces import tempfile import os import numpy as np # Load the DAC model with padding set to False for chunking model, variables = dac_jax.load_model(model_type="44khz", padding=False) # Jit-compile these functions because they're used inside a loop over chunks. @jax.jit def compress_chunk(x): return model.apply(variables, x, method='compress_chunk') @jax.jit def decompress_chunk(c): return model.apply(variables, c, method='decompress_chunk') @spaces.GPU def encode(audio_file_path): try: # Load audio with librosa, specifying duration signal, sample_rate = librosa.load(audio_file_path, sr=44100, mono=True, duration=.5) # Set duration as needed signal = jnp.array(signal, dtype=jnp.float32) while signal.ndim < 3: signal = jnp.expand_dims(signal, axis=0) win_duration = 0.5 # Adjust based on your GPU's memory size dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration) # Save to a temporary file with tempfile.NamedTemporaryFile(delete=False, suffix=".dac") as temp_file: dac_file.save(temp_file.name) return temp_file.name except Exception as e: gr.Warning(f"An error occurred during encoding: {e}") return None @spaces.GPU def decode(compressed_dac_file): try: # Load from the uploaded file with tempfile.NamedTemporaryFile(delete=False, suffix=".dac") as temp_file: temp_file.write(compressed_dac_file.file.read()) temp_file_path = temp_file.name dac_file = dac_jax.DACFile.load(temp_file_path) y = model.decompress(decompress_chunk, dac_file) decoded_audio = np.array(y).squeeze() os.unlink(temp_file_path) return (44100, decoded_audio) except Exception as e: gr.Warning(f"An error occurred during decoding: {e}") return None # Gradio interface with gr.Blocks() as demo: gr.Markdown("