# File size: 2,896 Bytes
# f18f98b d8d7a8d 60f3b28 d8d7a8d a2cc897 bd40662 d8d7a8d c27eb74 5f18cf7 d8d7a8d dd53483 60f3b28 5956a6f c27eb74 5f18cf7 60f3b28 5f18cf7 dfdd7ad 5f18cf7 d8d7a8d c27eb74 5f18cf7 9651813 c9a89ac c27eb74 60f3b28 d8d7a8d c27eb74 5f18cf7 9651813 d8d7a8d bd40662 c27eb74 9651813 5f18cf7 c27eb74 d07be48 60f3b28 44bab11 f18f98b c27eb74 d8d7a8d 44bab11 c6e9cf1 44bab11 f18f98b c27eb74 d8d7a8d c27eb74 44bab11 d8d7a8d f18f98b c9a89ac |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import gradio as gr
import jax
import jax.numpy as jnp
import librosa
import dac_jax
import spaces
import tempfile
import os
import numpy as np
# Build the 44 kHz DAC model and its variables once, at import time.
# Padding is turned off because audio is processed chunk by chunk.
model, variables = dac_jax.load_model(
    model_type="44khz",
    padding=False,
)
# Jit-compile these functions because they're used inside a loop over chunks.
@jax.jit
def compress_chunk(x):
    """Apply the model's ``compress_chunk`` method to a single chunk ``x``.

    Uses the module-level ``model`` and ``variables``; jit-compiled because it
    is invoked repeatedly, once per chunk.
    """
    compressed = model.apply(variables, x, method='compress_chunk')
    return compressed
@jax.jit
def decompress_chunk(c):
    """Apply the model's ``decompress_chunk`` method to a single chunk ``c``.

    Uses the module-level ``model`` and ``variables``; jit-compiled because it
    is invoked repeatedly, once per chunk.
    """
    restored = model.apply(variables, c, method='decompress_chunk')
    return restored
@spaces.GPU
def encode(audio_file_path, duration=0.5, win_duration=0.5):
    """Compress an audio file into a temporary ``.dac`` file.

    Args:
        audio_file_path: Path of the audio file to load.
        duration: Maximum number of seconds of audio to load, forwarded to
            ``librosa.load`` (``None`` loads the whole file). Defaults to
            0.5, the value that used to be hard-coded here.
        win_duration: Chunk window length in seconds, forwarded to
            ``model.compress``. Defaults to 0.5; tune it to the GPU's
            memory size.

    Returns:
        The path of the written ``.dac`` file, or ``None`` when encoding
        failed (a ``gr.Warning`` carrying the error is issued instead of
        propagating the exception).
    """
    dac_path = None
    try:
        # Load as 44.1 kHz mono, limited to `duration` seconds.
        signal, sample_rate = librosa.load(
            audio_file_path, sr=44100, mono=True, duration=duration
        )
        signal = jnp.array(signal, dtype=jnp.float32)

        # Prepend singleton axes until the array is 3-D.
        while signal.ndim < 3:
            signal = jnp.expand_dims(signal, axis=0)

        dac_file = model.compress(
            compress_chunk, signal, sample_rate, win_duration=win_duration
        )

        # Reserve a persistent temp path, then close the handle *before*
        # writing so the save does not compete with our own open descriptor.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".dac") as temp_file:
            dac_path = temp_file.name
        dac_file.save(dac_path)
        return dac_path
    except Exception as e:
        # The temp file is created with delete=False; remove it ourselves so
        # a failed save does not leave an orphan behind.
        if dac_path is not None and os.path.exists(dac_path):
            os.unlink(dac_path)
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
@spaces.GPU
def decode(compressed_dac_file):
    """Decompress an uploaded ``.dac`` file back into audio samples.

    Args:
        compressed_dac_file: The uploaded file. Either a filesystem path
            (``str`` / ``os.PathLike``), which is loaded directly, or an
            object exposing a readable ``.file`` attribute, whose bytes are
            first copied to a temporary ``.dac`` file.
            NOTE(review): which of the two the UI delivers depends on the
            installed Gradio version — confirm against the deployment.

    Returns:
        A ``(44100, samples)`` tuple where ``samples`` is the squeezed NumPy
        array of decoded audio, or ``None`` when decoding failed (a
        ``gr.Warning`` carrying the error is issued instead of propagating
        the exception).
    """
    temp_file_path = None
    try:
        if isinstance(compressed_dac_file, (str, os.PathLike)):
            # Already a path on disk: no copy needed, and nothing of ours
            # to clean up afterwards.
            dac_path = compressed_dac_file
        else:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".dac") as temp_file:
                # Record the name first so cleanup still happens if the
                # copy itself fails.
                temp_file_path = temp_file.name
                temp_file.write(compressed_dac_file.file.read())
            # The handle is closed here, so the bytes are flushed before load.
            dac_path = temp_file_path

        dac_file = dac_jax.DACFile.load(dac_path)
        y = model.decompress(decompress_chunk, dac_file)
        decoded_audio = np.array(y).squeeze()
        return (44100, decoded_audio)
    except Exception as e:
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
    finally:
        # Always remove our temporary copy, including on the error path
        # (previously it leaked whenever load/decompress raised).
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.unlink(temp_file_path)
# Gradio interface: one tab to compress audio to a .dac file and one tab to
# turn a .dac file back into audio.
# NOTE(review): the nesting below was reconstructed from the `with`
# structure; confirm that each button is meant to sit inside its input row.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")

    with gr.Tab("Encode"):
        with gr.Row():
            # type="filepath" hands `encode` a path it can pass to librosa.
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.dac)")
        encode_button.click(encode, inputs=audio_input, outputs=encoded_output)

    with gr.Tab("Decode"):
        with gr.Row():
            compressed_input = gr.File(label="Compressed Audio (.dac)")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)

# Enable request queuing, then start the app.
demo.queue().launch()