owiedotch committed on
Commit
60f3b28
1 Parent(s): dd53483

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -34
app.py CHANGED
@@ -1,52 +1,62 @@
1
  import gradio as gr
2
- import spaces
3
- import torch
4
- import torchaudio
5
- from encodec import EncodecModel
6
- from encodec.utils import convert_audio
7
- from encodec.compress import compress_to_file, decompress_from_file
8
  import io
 
9
 
10
- # Check for CUDA availability and set device
11
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
- print(f"Using device: {device}")
13
 
14
- # Load the Encodec model and move it to the selected device
15
- model = EncodecModel.encodec_model_48khz().to(device)
16
- model.set_target_bandwidth(6.0)
17
-
18
- @spaces.GPU # Indicate GPU usage for Spaces environment (if applicable)
19
  def encode(audio_file_path):
20
  try:
21
- # Load and pre-process the audio waveform
22
- wav, sr = torchaudio.load(audio_file_path)
 
 
 
 
23
 
24
- # Convert to mono if necessary
25
- if wav.shape[0] > 1:
26
- wav = torch.mean(wav, dim=0, keepdim=True)
27
 
28
- wav = convert_audio(wav, sr, model.sample_rate, model.channels)
29
- wav = wav.to(device) # Move the input audio to the selected device
 
30
 
31
- # Compress to ecdc file in memory
32
  output = io.BytesIO()
33
- compress_to_file(model, wav, output)
34
  output.seek(0)
35
 
36
- return output
37
 
38
  except Exception as e:
39
  gr.Warning(f"An error occurred during encoding: {e}")
40
  return None
41
 
42
- @spaces.GPU
43
- def decode(compressed_audio_file):
44
  try:
45
- # Decompress audio
46
- wav, sr = decompress_from_file(compressed_audio_file, device=device) # Pass the device to decompress_from_file
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- # Convert the decoded audio to a numpy array for Gradio output
49
- decoded_audio = wav.cpu().numpy()
50
 
51
  return decoded_audio
52
 
@@ -56,24 +66,24 @@ def decode(compressed_audio_file):
56
 
57
  # Gradio interface
58
  with gr.Blocks() as demo:
59
- gr.Markdown("<h1 style='text-align: center;'>Audio Compression with Encodec</h1>")
60
 
61
  with gr.Tab("Encode"):
62
  with gr.Row():
63
  audio_input = gr.Audio(type="filepath", label="Input Audio")
64
  encode_button = gr.Button("Encode", variant="primary")
65
  with gr.Row():
66
- encoded_output = gr.File(label="Compressed Audio (.ecdc)")
67
 
68
  encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
69
 
70
  with gr.Tab("Decode"):
71
  with gr.Row():
72
- compressed_input = gr.File(label="Compressed Audio (.ecdc)")
73
  decode_button = gr.Button("Decode", variant="primary")
74
  with gr.Row():
75
  decoded_output = gr.Audio(label="Decompressed Audio")
76
 
77
- decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)
78
 
79
  demo.queue().launch()
 
1
  import gradio as gr
2
+ import jax.numpy as jnp
3
+ import librosa
4
+ import dac_jax
5
+ from dac_jax.audio_utils import volume_norm, db2linear
 
 
6
  import io
7
+ import soundfile as sf
8
 
9
+ # Load the DAC model
10
+ model, variables = dac_jax.load_model(model_type="44khz")
11
+ model = model.bind(variables)
12
 
13
+ @spaces.GPU
 
 
 
 
14
  def encode(audio_file_path):
15
  try:
16
+ # Load a mono audio file
17
+ signal, sample_rate = librosa.load(audio_file_path, sr=44100, mono=True)
18
+
19
+ signal = jnp.array(signal, dtype=jnp.float32)
20
+ while signal.ndim < 3:
21
+ signal = jnp.expand_dims(signal, axis=0)
22
 
23
+ target_db = -16 # Normalize audio to -16 dB
24
+ x, input_db = volume_norm(signal, target_db, sample_rate)
 
25
 
26
+ # Encode audio signal
27
+ x = model.preprocess(x, sample_rate)
28
+ z, codes, latents, commitment_loss, codebook_loss = model.encode(x, train=False)
29
 
30
+ # Save the encoded data (codes and latents)
31
  output = io.BytesIO()
32
+ torch.save({'codes': codes, 'latents': latents, 'input_db': input_db, 'target_db': target_db}, output)
33
  output.seek(0)
34
 
35
+ return output
36
 
37
  except Exception as e:
38
  gr.Warning(f"An error occurred during encoding: {e}")
39
  return None
40
 
41
+ @spaces.GPU
42
+ def decode(encoded_data_file):
43
  try:
44
+ # Load the encoded data
45
+ encoded_data = torch.load(encoded_data_file)
46
+ codes = encoded_data['codes']
47
+ latents = encoded_data['latents']
48
+ input_db = encoded_data['input_db']
49
+ target_db = encoded_data['target_db']
50
+
51
+ # Decode audio signal
52
+ z = model.quantizer.decode(codes, latents)
53
+ y = model.decode(z)
54
+
55
+ # Undo previous loudness normalization
56
+ y = y * db2linear(input_db - target_db)
57
 
58
+ # Convert to numpy array and squeeze to remove extra dimensions
59
+ decoded_audio = np.array(y).squeeze()
60
 
61
  return decoded_audio
62
 
 
66
 
67
  # Gradio interface
68
  with gr.Blocks() as demo:
69
+ gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")
70
 
71
  with gr.Tab("Encode"):
72
  with gr.Row():
73
  audio_input = gr.Audio(type="filepath", label="Input Audio")
74
  encode_button = gr.Button("Encode", variant="primary")
75
  with gr.Row():
76
+ encoded_output = gr.File(label="Encoded Data")
77
 
78
  encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
79
 
80
  with gr.Tab("Decode"):
81
  with gr.Row():
82
+ encoded_input = gr.File(label="Encoded Data")
83
  decode_button = gr.Button("Decode", variant="primary")
84
  with gr.Row():
85
  decoded_output = gr.Audio(label="Decompressed Audio")
86
 
87
+ decode_button.click(decode, inputs=encoded_input, outputs=decoded_output)
88
 
89
  demo.queue().launch()