"""Gradio app that compresses audio to DAC files with ``dac_jax`` and back.

``encode`` turns an uploaded audio file into an in-memory DAC file and
``decode`` turns a DAC file back into a NumPy audio array. Both run the model
chunk-by-chunk through the JIT-compiled ``compress_chunk`` /
``decompress_chunk`` callbacks.
"""
import io
import os
import tempfile

import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import numpy as np  # needed by decode(); previously missing, causing a NameError
import dac_jax
from dac_jax.audio_utils import volume_norm, db2linear
import soundfile as sf
import spaces

# Try to attach to a Colab-style TPU runtime; if that fails for any reason,
# continue with whatever backend JAX selects (GPU or CPU).
try:
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
    print("Connected to TPU")
except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit propagate
    print("No TPU detected, using GPU or CPU.")

# Load the 44 kHz DAC model; padding is disabled so audio can be processed in chunks.
model, variables = dac_jax.load_model(model_type="44khz", padding=False)


@jax.jit
def compress_chunk(x):
    """Encode one audio chunk with the loaded model (JIT-compiled).

    Passed to ``model.compress`` as the per-chunk callback in ``encode``.
    """
    return model.apply(variables, x, method='compress_chunk')


@jax.jit
def decompress_chunk(c):
    """Decode one chunk of codes with the loaded model (JIT-compiled).

    Passed to ``model.decompress`` as the per-chunk callback in ``decode``.
    """
    return model.apply(variables, c, method='decompress_chunk')


@spaces.GPU
def encode(audio_file):
    """Compress an uploaded audio file into an in-memory DAC file.

    Args:
        audio_file: Either a file-like object exposing ``read()`` (its bytes
            are spooled to a temporary ``.wav`` file that is removed
            afterwards) or a path-like value naming an audio file on disk.

    Returns:
        An ``io.BytesIO`` holding the saved DAC file, rewound to position 0,
        or ``None`` on any failure (a ``gr.Warning`` is shown instead).
    """
    # Bound before the try so the cleanup in ``finally`` never hits a NameError,
    # even when the failure happens before or during temp-file creation.
    temp_audio_file_path = None
    try:
        if hasattr(audio_file, "read"):
            # Spool the upload to disk so librosa can load it by path.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
                # Record the path before writing so a failed write is still cleaned up.
                temp_audio_file_path = temp_audio_file.name
                temp_audio_file.write(audio_file.read())
            audio_path = temp_audio_file_path
        else:
            # Path-like input (e.g. a filepath string); raises TypeError for None.
            audio_path = os.fspath(audio_file)

        # Load as mono, resampled to 44.1 kHz to match the 44khz model.
        signal, sample_rate = librosa.load(audio_path, sr=44100, mono=True)
        signal = jnp.array(signal, dtype=jnp.float32)
        # Prepend singleton axes until rank 3; a mono 1-D signal becomes (1, 1, T).
        while signal.ndim < 3:
            signal = jnp.expand_dims(signal, axis=0)

        # Chunk window for model.compress (presumably seconds); tune it to the
        # available accelerator memory.
        win_duration = 5.0
        dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)

        # NOTE(review): assumes DACFile.save accepts a writable file-like object and
        # that the Gradio output component accepts a BytesIO -- confirm both; if save
        # only takes a path, this raises and encode returns None.
        output = io.BytesIO()
        dac_file.save(output)
        output.seek(0)
        return output
    except Exception as e:
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
    finally:
        if temp_audio_file_path is not None:
            os.remove(temp_audio_file_path)


@spaces.GPU
def decode(compressed_dac_file):
    """Decompress a DAC file back into an audio array.

    Args:
        compressed_dac_file: Whatever ``dac_jax.DACFile.load`` accepts
            (as received from the UI).

    Returns:
        A NumPy array of the decoded audio with singleton axes squeezed out,
        or ``None`` on any failure (a ``gr.Warning`` is shown instead).
    """
    # NOTE(review): the input comes from a user upload; if DACFile.load
    # deserialises via pickle (e.g. np.load(..., allow_pickle=True)), that is
    # unsafe on untrusted files -- confirm against dac_jax.
    # NOTE(review): only samples are returned, no sample rate; if the output is a
    # numpy-typed gr.Audio it likely expects (sample_rate, data) -- confirm.
    try:
        dac_file = dac_jax.DACFile.load(compressed_dac_file)
        y = model.decompress(decompress_chunk, dac_file)
        # Copy the JAX array to a host NumPy array and drop the singleton axes.
        decoded_audio = np.array(y).squeeze()
        return decoded_audio
    except Exception as e:
        gr.Warning(f"An error occurred during decoding: {e}")
        return None


# NOTE(review): the Gradio UI definition begins in the fragment below and
# continues past the visible source; the fragment is kept verbatim.
# Gradio interface with gr.Blocks() as demo: gr.Markdown("