# (Header below was scraped from the Hugging Face file viewer; kept as comments.)
# dac / app.py
# owiedotch's picture
# Update app.py
# c9a89ac verified
# raw
# history blame
# No virus
# 3.69 kB
import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import dac_jax
from dac_jax.audio_utils import volume_norm, db2linear
import soundfile as sf
import spaces
import tempfile
import os
import numpy as np
# Probe for a TPU runtime (Colab-style helper). No device is selected
# explicitly here; if the probe fails, JAX's default backend selection applies.
try:
    import jax.tools.colab_tpu
    jax.tools.colab_tpu.setup_tpu()
    print("Connected to TPU")
except Exception:
    # ImportError when the helper module is absent from this JAX build, or an
    # error raised by setup_tpu() when no TPU is reachable. Unlike a bare
    # ``except:``, this does not swallow KeyboardInterrupt / SystemExit.
    print("No TPU detected, using GPU or CPU.")
# Load the 44 kHz DAC model together with its variables (passed to
# model.apply below). Padding is disabled because the audio is compressed
# and decompressed in chunks.
model, variables = dac_jax.load_model(
    model_type="44khz",
    padding=False,
)
# GPU-accelerated and jit-compiled chunk processing functions
@spaces.GPU
@jax.jit
def compress_chunk(x):
    """Encode one chunk of audio with the model's ``compress_chunk`` method.

    Handed to ``model.compress`` as its per-chunk callback. ``model`` and
    ``variables`` are module-level globals; under ``jax.jit`` they are
    captured when the function is traced.

    Args:
        x: One chunk of the input signal. NOTE(review): presumably laid out
            like the (batch, channels, samples) array that ``encode`` builds;
            confirm against dac_jax.
    """
    chunk_method = 'compress_chunk'
    return model.apply(variables, x, method=chunk_method)
@spaces.GPU
@jax.jit
def decompress_chunk(c):
    """Decode one chunk of compressed data with the model's ``decompress_chunk`` method.

    Handed to ``model.decompress`` as its per-chunk callback. It closes over
    the module-level ``model`` and ``variables`` and is compiled with
    ``jax.jit``.

    Args:
        c: Compressed representation of one chunk, as produced inside
            ``model.decompress`` (its structure is defined by dac_jax).
    """
    chunk_method = 'decompress_chunk'
    return model.apply(variables, c, method=chunk_method)
def ensure_mono(audio, sr):
    """Return ``(audio, sr)`` with multi-dimensional audio downmixed to mono.

    Arrays with ``ndim <= 1`` are returned unchanged (the same object). The
    sample rate is passed through untouched.
    """
    if audio.ndim <= 1:
        return audio, sr
    # librosa.to_mono averages over the leading axes of a (..., n) array, so
    # the transpose assumes the input is laid out as (samples, channels).
    # NOTE(review): librosa.load(mono=False) returns (channels, samples);
    # confirm the intended layout before passing such arrays here.
    downmixed = librosa.to_mono(audio.T)
    return downmixed, sr
@spaces.GPU
def encode(audio_file_path):
    """Compress an uploaded audio file into a DAC (.dac) file.

    Args:
        audio_file_path: Path to the input audio, as provided by the
            ``gr.Audio(type="filepath")`` component; ``None`` when the user
            clicks "Encode" without uploading anything.

    Returns:
        Path of a new temporary ``.dac`` file (created with ``delete=False``,
        so it is not removed automatically), or ``None`` if there was no input
        or encoding failed. Failures are reported via ``gr.Warning``.
    """
    if audio_file_path is None:
        gr.Warning("Please upload an audio file to encode.")
        return None
    try:
        # Resample to 44.1 kHz (matching the "44khz" model) and downmix to a
        # 1-D mono array; mono=True is librosa's default, made explicit here.
        signal, sample_rate = librosa.load(audio_file_path, sr=44100, mono=True)
        signal, sample_rate = ensure_mono(signal, sample_rate)
        signal = jnp.array(signal, dtype=jnp.float32)
        # Add batch and channel axes -> shape (1, 1, num_samples).
        signal = jnp.expand_dims(signal, axis=(0, 1))

        # Chunk duration in seconds; lower it if the accelerator runs out of memory.
        win_duration = 5.0
        dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)

        # Reserve a unique .dac path, then close the handle before writing, so
        # save() can open the path itself on every platform (an open
        # NamedTemporaryFile cannot be reopened by name on Windows).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".dac") as temp_file:
            output_path = temp_file.name
        try:
            dac_file.save(output_path)
        except Exception:
            # Don't leave an empty or partial file behind on failure.
            os.unlink(output_path)
            raise
        return output_path
    except Exception as e:
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
@spaces.GPU
def decode(compressed_dac_file):
    """Decompress an uploaded .dac file back to audio.

    Args:
        compressed_dac_file: Value from the ``gr.File`` component. Depending on
            the Gradio version this is a filepath string (possibly a ``str``
            subclass carrying ``.name``) or a tempfile-like object whose
            ``.name`` is the path on disk; ``None`` if nothing was uploaded.

    Returns:
        ``(44100, audio)`` with ``audio`` a squeezed NumPy array, the
        ``(sample_rate, data)`` tuple form accepted by ``gr.Audio``; or ``None``
        on missing input or failure, reported via ``gr.Warning``.
    """
    if compressed_dac_file is None:
        gr.Warning("Please upload a .dac file to decode.")
        return None
    try:
        # Resolve the upload's on-disk path: tempfile wrappers and NamedString
        # expose it as ``.name``, a plain str *is* the path. Loading in place
        # replaces the former copy-to-tempfile step, which read through a
        # ``.file`` attribute that path strings lack and leaked the copy
        # whenever loading or decompression raised.
        dac_path = getattr(compressed_dac_file, "name", compressed_dac_file)
        # NOTE(review): uploads are untrusted. If DACFile.load deserializes with
        # pickle (upstream DAC uses np.load(..., allow_pickle=True)), a crafted
        # file could execute code; confirm dac_jax's loader before exposing publicly.
        dac_file = dac_jax.DACFile.load(dac_path)
        # Decompress using chunking.
        y = model.decompress(decompress_chunk, dac_file)
        # Convert to NumPy and drop singleton (batch/channel) dimensions.
        decoded_audio = np.array(y).squeeze()
        return (44100, decoded_audio)
    except Exception as e:
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
# Gradio interface: one tab per direction (encode / decode). Component creation
# order inside each ``with`` context determines the on-screen layout.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")

    # Encode tab: audio file in -> downloadable .dac file out.
    with gr.Tab("Encode"):
        with gr.Row():
            # type="filepath" means encode() receives a path on disk.
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.dac)")
        encode_button.click(encode, inputs=audio_input, outputs=encoded_output)

    # Decode tab: .dac file in -> playable audio out.
    with gr.Tab("Decode"):
        with gr.Row():
            compressed_input = gr.File(label="Compressed Audio (.dac)")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            # Displays the (sample_rate, array) tuple returned by decode().
            decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)

# Enable Gradio's request queue, then start the server.
demo.queue().launch()