import gradio as gr import jax import jax.numpy as jnp import librosa import dac_jax from dac_jax.audio_utils import volume_norm, db2linear import spaces import tempfile import os import numpy as np # Download a model and bind variables to it. model, variables = dac_jax.load_model(model_type="44khz") model = model.bind(variables) @spaces.GPU def encode(audio_file_path): try: # Load audio with librosa, specifying duration signal, sample_rate = librosa.load(audio_file_path, sr=44100, mono=True, duration=5.0) # Set duration as needed signal = jnp.array(signal, dtype=jnp.float32) while signal.ndim < 3: signal = jnp.expand_dims(signal, axis=0) target_db = -16 # Normalize audio to -16 dB x, input_db = volume_norm(signal, target_db, sample_rate) # Encode audio signal x = model.preprocess(x, sample_rate) z, codes, latents, commitment_loss, codebook_loss = model.encode(x, train=False) # Save encoded data to a temporary file (using numpy.savez for now) with tempfile.NamedTemporaryFile(delete=False, suffix=".npz") as temp_file: np.savez(temp_file.name, z=z, codes=codes, latents=latents, input_db=input_db) return temp_file.name except Exception as e: gr.Warning(f"An error occurred during encoding: {e}") return None @spaces.GPU def decode(compressed_file_path): # Changed input to compressed_file_path try: # Load encoded data directly from the file path data = np.load(compressed_file_path) # No need for temporary files z = data['z'] codes = data['codes'] latents = data['latents'] input_db = data['input_db'] # Decode audio signal y = model.decode(z, length=z.shape[1] * model.hop_length) # Undo previous loudness normalization y = y * db2linear(input_db - (-16)) # Using -16 as the target_db decoded_audio = np.array(y).squeeze() return (44100, decoded_audio) except Exception as e: gr.Warning(f"An error occurred during decoding: {e}") return None # Gradio interface with gr.Blocks() as demo: gr.Markdown("