owiedotch commited on
Commit
d8d7a8d
1 Parent(s): 76e481a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -28
app.py CHANGED
@@ -1,15 +1,24 @@
1
  import gradio as gr
2
- import spaces
3
  import jax.numpy as jnp
4
  import librosa
5
  import dac_jax
6
  from dac_jax.audio_utils import volume_norm, db2linear
7
  import io
8
  import soundfile as sf
 
 
 
 
9
 
10
- # Load the DAC model
11
- model, variables = dac_jax.load_model(model_type="44khz")
12
- model = model.bind(variables)
 
 
 
 
 
13
 
14
  @spaces.GPU
15
  def encode(audio_file_path):
@@ -21,40 +30,31 @@ def encode(audio_file_path):
21
  while signal.ndim < 3:
22
  signal = jnp.expand_dims(signal, axis=0)
23
 
24
- target_db = -16 # Normalize audio to -16 dB
25
- x, input_db = volume_norm(signal, target_db, sample_rate)
26
 
27
- # Encode audio signal
28
- x = model.preprocess(x, sample_rate)
29
- z, codes, latents, commitment_loss, codebook_loss = model.encode(x, train=False)
30
 
31
- # Save the encoded data (codes and latents)
32
  output = io.BytesIO()
33
- torch.save({'codes': codes, 'latents': latents, 'input_db': input_db, 'target_db': target_db}, output)
34
  output.seek(0)
35
 
36
- return output
37
 
38
  except Exception as e:
39
  gr.Warning(f"An error occurred during encoding: {e}")
40
  return None
41
 
42
  @spaces.GPU
43
- def decode(encoded_data_file):
44
  try:
45
- # Load the encoded data
46
- encoded_data = torch.load(encoded_data_file)
47
- codes = encoded_data['codes']
48
- latents = encoded_data['latents']
49
- input_db = encoded_data['input_db']
50
- target_db = encoded_data['target_db']
51
-
52
- # Decode audio signal
53
- z = model.quantizer.decode(codes, latents)
54
- y = model.decode(z)
55
 
56
- # Undo previous loudness normalization
57
- y = y * db2linear(input_db - target_db)
58
 
59
  # Convert to numpy array and squeeze to remove extra dimensions
60
  decoded_audio = np.array(y).squeeze()
@@ -74,17 +74,17 @@ with gr.Blocks() as demo:
74
  audio_input = gr.Audio(type="filepath", label="Input Audio")
75
  encode_button = gr.Button("Encode", variant="primary")
76
  with gr.Row():
77
- encoded_output = gr.File(label="Encoded Data")
78
 
79
  encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
80
 
81
  with gr.Tab("Decode"):
82
  with gr.Row():
83
- encoded_input = gr.File(label="Encoded Data")
84
  decode_button = gr.Button("Decode", variant="primary")
85
  with gr.Row():
86
  decoded_output = gr.Audio(label="Decompressed Audio")
87
 
88
- decode_button.click(decode, inputs=encoded_input, outputs=decoded_output)
89
 
90
  demo.queue().launch()
 
1
  import gradio as gr
2
+ import jax
3
  import jax.numpy as jnp
4
  import librosa
5
  import dac_jax
6
  from dac_jax.audio_utils import volume_norm, db2linear
7
  import io
8
  import soundfile as sf
9
+ import spaces
10
+
11
+ # Load the DAC model with padding set to False for chunking
12
+ model, variables = dac_jax.load_model(model_type="44khz", padding=False)
13
 
14
+ # Jit-compile the chunk processing functions for efficiency
15
+ @jax.jit
16
+ def compress_chunk(x):
17
+ return model.apply(variables, x, method='compress_chunk')
18
+
19
+ @jax.jit
20
+ def decompress_chunk(c):
21
+ return model.apply(variables, c, method='decompress_chunk')
22
 
23
  @spaces.GPU
24
  def encode(audio_file_path):
 
30
  while signal.ndim < 3:
31
  signal = jnp.expand_dims(signal, axis=0)
32
 
33
+ # Set chunk duration based on available GPU memory (adjust as needed)
34
+ win_duration = 0.5 # You might need to experiment with this value
35
 
36
+ # Compress using chunking
37
+ dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)
 
38
 
39
+ # Save the compressed DAC file to BytesIO
40
  output = io.BytesIO()
41
+ dac_file.save(output)
42
  output.seek(0)
43
 
44
+ return output
45
 
46
  except Exception as e:
47
  gr.Warning(f"An error occurred during encoding: {e}")
48
  return None
49
 
50
  @spaces.GPU
51
+ def decode(compressed_dac_file):
52
  try:
53
+ # Load the compressed DAC file
54
+ dac_file = dac_jax.DACFile.load(compressed_dac_file)
 
 
 
 
 
 
 
 
55
 
56
+ # Decompress using chunking
57
+ y = model.decompress(decompress_chunk, dac_file)
58
 
59
  # Convert to numpy array and squeeze to remove extra dimensions
60
  decoded_audio = np.array(y).squeeze()
 
74
  audio_input = gr.Audio(type="filepath", label="Input Audio")
75
  encode_button = gr.Button("Encode", variant="primary")
76
  with gr.Row():
77
+ encoded_output = gr.File(label="Compressed Audio (.dac)")
78
 
79
  encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
80
 
81
  with gr.Tab("Decode"):
82
  with gr.Row():
83
+ compressed_input = gr.File(label="Compressed Audio (.dac)")
84
  decode_button = gr.Button("Decode", variant="primary")
85
  with gr.Row():
86
  decoded_output = gr.Audio(label="Decompressed Audio")
87
 
88
+ decode_button.click(decode, inputs=compressed_input, outputs=decoded_output)
89
 
90
  demo.queue().launch()