File size: 3,939 Bytes
f18f98b d8d7a8d 60f3b28 d8d7a8d a2cc897 bd40662 0586f5f d8d7a8d c27eb74 feebd9f d8d7a8d feebd9f d8d7a8d dd53483 feebd9f 60f3b28 5956a6f c27eb74 feebd9f 60f3b28 feebd9f dfdd7ad d8d7a8d 9fbeded dfdd7ad d8d7a8d c27eb74 9651813 c9a89ac 9fbeded c9a89ac c27eb74 9fbeded c27eb74 60f3b28 d8d7a8d c27eb74 9651813 d8d7a8d 9651813 60f3b28 d8d7a8d 7e97379 60f3b28 bd40662 c27eb74 9651813 a2cc897 c27eb74 d07be48 60f3b28 44bab11 f18f98b c27eb74 d8d7a8d 44bab11 c6e9cf1 44bab11 f18f98b c27eb74 d8d7a8d c27eb74 44bab11 d8d7a8d f18f98b c9a89ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import dac_jax
from dac_jax.audio_utils import volume_norm, db2linear
import soundfile as sf
import spaces
import tempfile
import os
import numpy as np
# Try to attach to a Colab TPU runtime. Any failure here (module missing,
# no TPU attached, setup error) is treated as "no TPU" and the script carries
# on; the message below is the only effect of the fallback path.
try:
    import jax.tools.colab_tpu

    jax.tools.colab_tpu.setup_tpu()
    print("Connected to TPU")
except Exception:
    # Deliberately best-effort, but narrowed from a bare `except:` so that
    # KeyboardInterrupt / SystemExit are no longer swallowed during startup.
    print("No TPU detected, using GPU or CPU.")

# Load the DAC model with padding set to False for chunking
model, variables = dac_jax.load_model(model_type="44khz", padding=False)
# Chunk-level callbacks: each is wrapped with jax.jit and then spaces.GPU.
@spaces.GPU
@jax.jit
def compress_chunk(x):
    """Apply the loaded model's ``compress_chunk`` method to ``x``.

    Uses the module-level ``model`` and ``variables``. ``encode`` hands this
    function to ``model.compress`` as its per-chunk callback.

    Args:
        x: The chunk to compress, forwarded unchanged to ``model.apply``.

    Returns:
        Whatever ``model.apply(..., method='compress_chunk')`` returns.
    """
    compressed = model.apply(variables, x, method="compress_chunk")
    return compressed
@spaces.GPU
@jax.jit
def decompress_chunk(c):
    """Apply the loaded model's ``decompress_chunk`` method to ``c``.

    Uses the module-level ``model`` and ``variables``. ``decode`` hands this
    function to ``model.decompress`` as its per-chunk callback.

    Args:
        c: The compressed chunk, forwarded unchanged to ``model.apply``.

    Returns:
        Whatever ``model.apply(..., method='decompress_chunk')`` returns.
    """
    restored = model.apply(variables, c, method="decompress_chunk")
    return restored
def ensure_mono(audio, sr):
    """Return ``(audio, sr)`` with ``audio`` reduced to a single channel.

    Args:
        audio: Array-like object exposing ``ndim`` (and ``.T`` when it has
            more than one dimension).
        sr: Sample rate; passed through untouched.

    Returns:
        A ``(audio, sr)`` tuple. Input with at most one dimension is returned
        as the very same object; otherwise the transposed array is handed to
        ``librosa.to_mono``.
    """
    # Already one-dimensional (or scalar): nothing to mix down.
    if audio.ndim <= 1:
        return audio, sr

    mono = librosa.to_mono(audio.T)
    return mono, sr
@spaces.GPU
def encode(audio_file_path):
    """Compress an audio file into a temporary ``.dac`` file.

    The audio is loaded with librosa at 44100 Hz, passed through
    ``ensure_mono``, given two leading axes, and compressed via
    ``model.compress`` with ``compress_chunk`` as the chunk callback.

    Args:
        audio_file_path: Path of the audio file to encode (the Gradio
            ``Audio`` input feeding this function uses ``type="filepath"``).

    Returns:
        Path of the written ``.dac`` file, or ``None`` if no input was given
        or encoding failed (a Gradio warning is raised in both cases). The
        returned file is not deleted by this function.
    """
    if not audio_file_path:
        # Nothing uploaded/recorded: report it plainly instead of surfacing
        # the audio loader's exception text.
        gr.Warning("Please provide an audio file to encode.")
        return None

    try:
        # Load and ensure mono audio
        signal, sample_rate = librosa.load(audio_file_path, sr=44100)
        signal, sample_rate = ensure_mono(signal, sample_rate)
        signal = jnp.array(signal, dtype=jnp.float32)
        # Add batch and channel dimensions
        signal = jnp.expand_dims(signal, axis=(0, 1))

        # Chunk duration in seconds (adjust to the available GPU memory)
        win_duration = 2.0

        # Compress using chunking
        dac_file = model.compress(
            compress_chunk, signal, sample_rate, win_duration=win_duration
        )

        # Reserve a unique path, then close the handle *before* writing so the
        # save call is the only writer of the file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".dac") as temp_file:
            dac_path = temp_file.name

        try:
            dac_file.save(dac_path)
        except Exception:
            # delete=False means nobody else will clean this up; do not leave
            # an orphaned temp file behind when saving fails.
            try:
                os.unlink(dac_path)
            except OSError:
                pass
            raise

        # Check file size
        if os.path.getsize(dac_path) == 0:
            print("WARNING: Compressed file size is 0 bytes!")

        # Return the temporary file path
        return dac_path
    except Exception as e:
        # UI boundary: report the failure to the log and the user, never crash
        # the Gradio worker.
        print(f"ERROR during encoding: {e}")
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
@spaces.GPU
def decode(compressed_dac_file):
    """Decompress an uploaded ``.dac`` file back to audio.

    The upload is copied to a private temporary ``.dac`` file, loaded with
    ``dac_jax.DACFile.load``, and decompressed via ``model.decompress`` with
    ``decompress_chunk`` as the chunk callback.

    Args:
        compressed_dac_file: Either a filesystem path (``str`` /
            ``os.PathLike``) or an object whose ``.file`` attribute is a
            readable binary file object.

    Returns:
        ``(44100, samples)`` where ``samples`` is the squeezed NumPy array of
        decoded audio, or ``None`` if no input was given or decoding failed
        (a Gradio warning is raised in both cases).
    """
    if compressed_dac_file is None:
        gr.Warning("Please provide a .dac file to decode.")
        return None

    temp_file_path = None
    try:
        # Accept both a plain path and a file-wrapper object.
        if isinstance(compressed_dac_file, (str, os.PathLike)):
            with open(compressed_dac_file, "rb") as uploaded:
                payload = uploaded.read()
        else:
            payload = compressed_dac_file.file.read()

        # Create a temporary file to save the uploaded content
        with tempfile.NamedTemporaryFile(delete=False, suffix=".dac") as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(payload)

        # Load the compressed DAC file
        dac_file = dac_jax.DACFile.load(temp_file_path)

        # Decompress using chunking
        y = model.decompress(decompress_chunk, dac_file)

        # Convert to numpy array and squeeze to remove extra dimensions
        decoded_audio = np.array(y).squeeze()

        return (44100, decoded_audio)  # Return sample rate and audio data
    except Exception as e:
        # UI boundary: mirror encode() by logging as well as warning the user.
        print(f"ERROR during decoding: {e}")
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
    finally:
        # Always remove the private copy, including when loading or
        # decompression raised (previously it was only removed on success).
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                pass
# Gradio interface
# Two tabs: "Encode" wires an audio input to encode(), "Decode" wires a file
# input to decode().
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been stripped — confirm which components belong inside each
# gr.Row() against the deployed app.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")
    with gr.Tab("Encode"):
        with gr.Row():
            # type="filepath": encode() receives a path rather than audio data.
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.dac)")
        # encode() returns the path of the .dac file (or None on failure).
        encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
    with gr.Tab("Decode"):
        with gr.Row():
            compressed_input = gr.File(label="Compressed Audio (.dac)")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")
        # decode() returns a (sample_rate, samples) tuple (or None on failure).
        decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)
# Enable the request queue, then start the app.
demo.queue().launch()