|
import io
import logging
import tempfile

import gradio as gr
import spaces
import jax.numpy as jnp
import librosa
import numpy as np
import soundfile as sf

import dac_jax
from dac_jax.audio_utils import volume_norm, db2linear
|
|
|
|
|
# Load the pretrained DAC model (the "44khz" variant, matching the 44100 Hz
# resampling used in encode) once at import time, so every request reuses it.
model, variables = dac_jax.load_model(model_type="44khz")

# Bind the loaded variables to the module so that methods such as
# model.preprocess / model.encode / model.decode can be called below without
# passing the variables explicitly (presumably Flax `Module.bind` — confirm).
model = model.bind(variables)
|
|
|
# Sample rate used for all codec I/O; matches the "44khz" model loaded above.
_SAMPLE_RATE = 44100

logger = logging.getLogger(__name__)

# Keys stored in (and required from) an encoded-data file.
_ENCODED_KEYS = ("codes", "latents", "input_db", "target_db")


def _save_encoded(codes, latents, input_db, target_db):
    """Write encoder outputs to a temporary ``.npz`` file and return its path.

    NumPy's ``.npz`` container holds plain arrays only, so the file can be read
    back with ``allow_pickle=False`` (no code execution on load).
    """
    with tempfile.NamedTemporaryFile(suffix=".npz", delete=False) as handle:
        np.savez(
            handle,
            codes=np.asarray(codes),
            latents=np.asarray(latents),
            input_db=np.asarray(input_db),
            target_db=np.asarray(target_db),
        )
    # delete=False keeps the file on disk so Gradio can serve it for download.
    return handle.name


def _load_encoded(path):
    """Read an ``.npz`` file written by ``_save_encoded`` into a dict of arrays.

    Raises:
        KeyError: if one of the expected arrays is missing.
        ValueError: if the file would require unpickling (rejected on purpose,
            since the file comes from an untrusted upload).
    """
    with np.load(path, allow_pickle=False) as archive:
        return {key: archive[key] for key in _ENCODED_KEYS}


def _upload_path(file_value):
    """Return a filesystem path for a value coming from a ``gr.File`` input.

    Depending on the Gradio version/config the component passes either a path
    string or a tempfile-like object exposing ``.name``.
    """
    return getattr(file_value, "name", file_value)


@spaces.GPU
def encode(audio_file_path):
    """Compress an audio file into DAC codes.

    Args:
        audio_file_path: Path to the uploaded audio (from
            ``gr.Audio(type="filepath")``).

    Returns:
        Path to an ``.npz`` file containing ``codes``, ``latents``, ``input_db``
        and ``target_db``, or ``None`` when no file was given or encoding
        failed (a Gradio warning is shown in that case).
    """
    if not audio_file_path:
        gr.Warning("Please provide an audio file to encode.")
        return None

    try:
        # Resample to the model rate and downmix to mono on load.
        signal, sample_rate = librosa.load(audio_file_path, sr=_SAMPLE_RATE, mono=True)
        signal = jnp.asarray(signal, dtype=jnp.float32)
        # Prepend singleton axes until rank 3; a 1-D mono signal becomes (1, 1, T).
        while signal.ndim < 3:
            signal = jnp.expand_dims(signal, axis=0)

        # Loudness-normalise before encoding; input_db is stored so that
        # decode can restore the original level.
        target_db = -16
        x, input_db = volume_norm(signal, target_db, sample_rate)

        x = model.preprocess(x, sample_rate)
        _z, codes, latents, _commitment_loss, _codebook_loss = model.encode(x, train=False)

        return _save_encoded(codes, latents, input_db, target_db)
    except Exception as e:
        # UI boundary: keep the app alive, but record the traceback instead of
        # hiding it (previously this masked a NameError on every call).
        logger.exception("Encoding failed")
        gr.Warning(f"An error occurred during encoding: {e}")
        return None


@spaces.GPU
def decode(encoded_data_file):
    """Reconstruct audio from a file produced by ``encode``.

    Args:
        encoded_data_file: Value from a ``gr.File`` input (path string or
            tempfile-like object) pointing at an ``.npz`` written by ``encode``.

    Returns:
        ``(sample_rate, samples)`` as expected by a numpy-valued ``gr.Audio``
        output, or ``None`` when no file was given or decoding failed (a Gradio
        warning is shown in that case).
    """
    if encoded_data_file is None:
        gr.Warning("Please provide an encoded data file to decode.")
        return None

    try:
        encoded_data = _load_encoded(_upload_path(encoded_data_file))
        codes = jnp.asarray(encoded_data["codes"])
        latents = jnp.asarray(encoded_data["latents"])
        input_db = jnp.asarray(encoded_data["input_db"])
        target_db = jnp.asarray(encoded_data["target_db"])

        z = model.quantizer.decode(codes, latents)
        y = model.decode(z)

        # Undo the loudness normalisation applied in encode.
        y = y * db2linear(input_db - target_db)

        # Drop singleton axes so the audio is a plain sample array.
        decoded_audio = np.asarray(y, dtype=np.float32).squeeze()
        return _SAMPLE_RATE, decoded_audio
    except Exception as e:
        logger.exception("Decoding failed")
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
|
|
|
|
|
def _build_demo():
    """Assemble the two-tab Gradio UI: one tab to encode audio, one to decode it."""
    with gr.Blocks() as blocks:
        gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")

        # Encode tab: audio file in -> encoded data file out.
        with gr.Tab("Encode"):
            with gr.Row():
                source_audio = gr.Audio(type="filepath", label="Input Audio")
                run_encode = gr.Button("Encode", variant="primary")
            with gr.Row():
                packed_file = gr.File(label="Encoded Data")
            run_encode.click(encode, inputs=source_audio, outputs=packed_file)

        # Decode tab: encoded data file in -> reconstructed audio out.
        with gr.Tab("Decode"):
            with gr.Row():
                packed_upload = gr.File(label="Encoded Data")
                run_decode = gr.Button("Decode", variant="primary")
            with gr.Row():
                restored_audio = gr.Audio(label="Decompressed Audio")
            run_decode.click(decode, inputs=packed_upload, outputs=restored_audio)

    return blocks


demo = _build_demo()

# Enable the request queue, then start the server.
demo.queue().launch()
|
|