import io
import tempfile

import gradio as gr
import jax.numpy as jnp
import librosa
import numpy as np
import soundfile as sf
import spaces
import torch

import dac_jax
from dac_jax.audio_utils import volume_norm, db2linear
# Load the pretrained 44.1 kHz DAC model and bind its parameters so the
# module's methods (preprocess/encode/decode) can be called directly below.
model, variables = dac_jax.load_model(model_type="44khz")
model = model.bind(variables)
@spaces.GPU
def encode(audio_file_path):
    """Encode an audio file into DAC codes/latents and save them to disk.

    Args:
        audio_file_path: Path to the input audio file (any format librosa reads).

    Returns:
        Path (str) to a saved payload containing 'codes', 'latents',
        'input_db' and 'target_db', or None if encoding failed.
    """
    try:
        # Load as mono at the model's native 44.1 kHz sample rate.
        signal, sample_rate = librosa.load(audio_file_path, sr=44100, mono=True)
        signal = jnp.array(signal, dtype=jnp.float32)
        # The model expects a (batch, channels, samples) tensor.
        while signal.ndim < 3:
            signal = jnp.expand_dims(signal, axis=0)

        # Normalize loudness to a fixed target; keep input_db so decode can undo it.
        target_db = -16  # Normalize audio to -16 dB
        x, input_db = volume_norm(signal, target_db, sample_rate)

        # Encode the (preprocessed) audio signal.
        x = model.preprocess(x, sample_rate)
        z, codes, latents, commitment_loss, codebook_loss = model.encode(x, train=False)

        # A gr.File output needs a real file on disk, not an in-memory buffer,
        # so persist the payload to a named temporary file and return its path.
        with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as tmp:
            torch.save(
                {
                    "codes": codes,
                    "latents": latents,
                    "input_db": input_db,
                    "target_db": target_db,
                },
                tmp,
            )
            return tmp.name
    except Exception as e:
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
@spaces.GPU
def decode(encoded_data_file):
    """Decode a payload produced by `encode` back into audio.

    Args:
        encoded_data_file: The encoded-data file (path or file object) from `encode`.

    Returns:
        A (sample_rate, samples) tuple suitable for gr.Audio, or None on failure.
    """
    try:
        # Load the encoded payload saved by `encode`.
        encoded_data = torch.load(encoded_data_file)
        codes = encoded_data['codes']
        latents = encoded_data['latents']
        input_db = encoded_data['input_db']
        target_db = encoded_data['target_db']

        # Reconstruct the latent z from the quantized codes, then decode to audio.
        z = model.quantizer.decode(codes, latents)
        y = model.decode(z)

        # Undo the loudness normalization applied during encoding.
        y = y * db2linear(input_db - target_db)

        # Squeeze batch/channel dims; gr.Audio expects (sample_rate, np.ndarray)
        # when returning a numpy waveform, so include the model's 44.1 kHz rate.
        decoded_audio = np.array(y).squeeze()
        return (44100, decoded_audio)
    except Exception as e:
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
# ---- Gradio interface ----
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")

    # "Encode" tab: audio in, compressed payload file out.
    with gr.Tab("Encode"):
        with gr.Row():
            audio_in = gr.Audio(type="filepath", label="Input Audio")
            btn_encode = gr.Button("Encode", variant="primary")
        with gr.Row():
            file_out = gr.File(label="Encoded Data")
        btn_encode.click(encode, inputs=audio_in, outputs=file_out)

    # "Decode" tab: compressed payload file in, audio out.
    with gr.Tab("Decode"):
        with gr.Row():
            file_in = gr.File(label="Encoded Data")
            btn_decode = gr.Button("Decode", variant="primary")
        with gr.Row():
            audio_out = gr.Audio(label="Decompressed Audio")
        btn_decode.click(decode, inputs=file_in, outputs=audio_out)

demo.queue().launch()