dac / app.py
owiedotch's picture
Update app.py
6b093ee verified
raw
history blame
No virus
3.03 kB
import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import dac_jax
from dac_jax.audio_utils import volume_norm, db2linear
import spaces
import tempfile
import os
import numpy as np
def _load_bound_model(model_type):
    """Load a pretrained DAC-JAX model and bind its variables to it.

    Binding attaches the parameters to the module, so methods such as
    ``preprocess``, ``encode`` and ``decode`` can be called on it directly.

    Returns:
        A ``(bound_module, variables)`` pair.
    """
    module, params = dac_jax.load_model(model_type=model_type)
    return module.bind(params), params


# Download the 44 kHz model once at import time; it is shared by the
# encode/decode handlers.
model, variables = _load_bound_model("44khz")
@spaces.GPU
def encode(audio_file_path, max_duration=5.0):
    """Encode an audio file with DAC-JAX and save the result as a ``.npz`` file.

    The audio is loaded as mono at 44.1 kHz (at most ``max_duration``
    seconds), loudness-normalized to -16 dB, preprocessed and passed through
    the model's encoder.

    Args:
        audio_file_path: Path to the uploaded audio file (Gradio ``filepath``
            audio input), or ``None`` when nothing was uploaded.
        max_duration: Maximum number of seconds to read from the file.
            ``None`` loads the whole file. Defaults to 5.0, the value that
            was previously hard-coded.

    Returns:
        Path of a temporary ``.npz`` file. It holds ``z``, ``codes``,
        ``latents`` and ``input_db``. It also holds the metadata ``decode``
        needs to reverse the process exactly: ``length`` (the original
        sample count), ``target_db`` and ``sample_rate``. Returns ``None``
        on failure, after showing a Gradio warning.
    """
    if audio_file_path is None:
        gr.Warning("Please upload an audio file to encode.")
        return None
    try:
        target_db = -16  # Normalize audio to -16 dB
        signal, sample_rate = librosa.load(
            audio_file_path, sr=44100, mono=True, duration=max_duration
        )
        signal = jnp.array(signal, dtype=jnp.float32)
        # Prepend singleton axes until the array is 3-D, e.g. (1, 1, samples).
        while signal.ndim < 3:
            signal = jnp.expand_dims(signal, axis=0)
        # Record the true sample count before normalization/preprocessing.
        # NOTE(review): preprocess presumably pads to a multiple of the hop
        # size; decode uses this value to trim the reconstruction back.
        length = signal.shape[-1]
        x, input_db = volume_norm(signal, target_db, sample_rate)
        x = model.preprocess(x, sample_rate)
        # The two training losses are not needed at inference time.
        z, codes, latents, _, _ = model.encode(x, train=False)
        # delete=False keeps the file around for Gradio to serve. Write
        # through the open handle instead of re-opening the file by name,
        # which fails on Windows while the handle is still open.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".npz") as temp_file:
            np.savez(
                temp_file,
                z=z,
                codes=codes,
                latents=latents,
                input_db=input_db,
                length=length,
                target_db=target_db,
                sample_rate=sample_rate,
            )
        return temp_file.name
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
@spaces.GPU
def decode(compressed_file_path):
    """Decode a ``.npz`` file produced by ``encode`` back into audio.

    Args:
        compressed_file_path: Path to the uploaded ``.npz`` file, or ``None``
            when nothing was uploaded.

    Returns:
        A ``(sample_rate, samples)`` tuple for the Gradio Audio output, where
        ``samples`` is an int16 array clipped to full scale. Returns ``None``
        on failure, after showing a Gradio warning.

    Files written by older versions of ``encode`` lack the ``length``,
    ``target_db`` and ``sample_rate`` entries. For those files the previous
    behavior is used: frame count times hop length, -16 dB and 44100 Hz.
    """
    if compressed_file_path is None:
        gr.Warning("Please upload a compressed .npz file to decode.")
        return None
    try:
        # The NpzFile returned for an .npz keeps the file open; the context
        # manager closes it. allow_pickle stays at its default (False)
        # because the file is user-supplied.
        with np.load(compressed_file_path) as data:
            stored = set(data.files)
            z = data["z"]
            input_db = data["input_db"]
            target_db = float(data["target_db"]) if "target_db" in stored else -16.0
            sample_rate = int(data["sample_rate"]) if "sample_rate" in stored else 44100
            if "length" in stored:
                length = int(data["length"])
            else:
                # Legacy fallback. NOTE(review): this assumes axis 1 of z is
                # the frame axis, and it may include preprocessing padding.
                length = z.shape[1] * model.hop_length
        y = model.decode(z, length=length)
        # Undo the loudness normalization applied in encode.
        y = y * db2linear(input_db - target_db)
        decoded_audio = np.asarray(y, dtype=np.float32).squeeze()
        # Convert to int16 here. Gradio converts float arrays by dividing by
        # their peak, which would discard the loudness restored above (and
        # yields NaNs for silent output). Verify against the installed
        # Gradio version.
        decoded_audio = (np.clip(decoded_audio, -1.0, 1.0) * 32767).astype(np.int16)
        return (sample_rate, decoded_audio)
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing.
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
# Gradio interface: one tab encodes audio into a downloadable .npz file,
# the other decodes such a file back into playable audio.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")
    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.npz)")
        encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
    with gr.Tab("Decode"):
        with gr.Row():
            # Only accept .npz uploads; decode() can only read files written by encode().
            compressed_input = gr.File(label="Compressed Audio (.npz)", file_types=[".npz"])
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)

# Launch only when run as a script, so importing this module (e.g. for
# tests or `gradio app.py` reload mode, which looks up `demo`) has no
# server side effect.
if __name__ == "__main__":
    demo.queue().launch()