owiedotch committed on
Commit
bd40662
1 Parent(s): a2cc897

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -3
app.py CHANGED
@@ -9,6 +9,10 @@ import soundfile as sf
9
  import spaces
10
  import tempfile
11
  import os
 
 
 
 
12
 
13
  # Check for CUDA availability and set device
14
  try:
@@ -32,6 +36,7 @@ def decompress_chunk(c):
32
 
33
  @spaces.GPU
34
  def encode(audio_file_path):
 
35
  try:
36
  # Load a mono audio file directly from the file path
37
  signal, sample_rate = librosa.load(audio_file_path, sr=44100, mono=True)
@@ -41,16 +46,20 @@ def encode(audio_file_path):
41
  signal = jnp.expand_dims(signal, axis=0)
42
 
43
  # Set chunk duration based on available GPU memory (adjust as needed)
44
- win_duration = 5.0
45
 
46
  # Compress using chunking
47
  dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)
48
 
49
  # Save the compressed DAC file to a temporary file
 
 
 
50
  with tempfile.NamedTemporaryFile(delete=False, suffix=".dac") as temp_file:
51
  dac_file.save(temp_file.name)
 
52
 
53
- return temp_file.name
54
 
55
  except Exception as e:
56
  gr.Warning(f"An error occurred during encoding: {e}")
@@ -66,7 +75,7 @@ def decode(compressed_dac_file):
66
  y = model.decompress(decompress_chunk, dac_file)
67
 
68
  # Convert to numpy array and squeeze to remove extra dimensions
69
- decoded_audio = jnp.array(y).squeeze()
70
 
71
  return (44100, decoded_audio) # Return sample rate and audio data
72
 
@@ -74,6 +83,13 @@ def decode(compressed_dac_file):
74
  gr.Warning(f"An error occurred during decoding: {e}")
75
  return None
76
 
 
 
 
 
 
 
 
77
  # Gradio interface
78
  with gr.Blocks() as demo:
79
  gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")
@@ -86,6 +102,7 @@ with gr.Blocks() as demo:
86
  encoded_output = gr.File(label="Compressed Audio (.dac)")
87
 
88
  encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
 
89
 
90
  with gr.Tab("Decode"):
91
  with gr.Row():
 
9
  import spaces
10
  import tempfile
11
  import os
12
+ import numpy as np
13
+
14
+ # Global variable to store the temporary file path
15
+ temp_file_path = None
16
 
17
  # Check for CUDA availability and set device
18
  try:
 
36
 
37
  @spaces.GPU
38
  def encode(audio_file_path):
39
+ global temp_file_path
40
  try:
41
  # Load a mono audio file directly from the file path
42
  signal, sample_rate = librosa.load(audio_file_path, sr=44100, mono=True)
 
46
  signal = jnp.expand_dims(signal, axis=0)
47
 
48
  # Set chunk duration based on available GPU memory (adjust as needed)
49
+ win_duration = 10.0
50
 
51
  # Compress using chunking
52
  dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)
53
 
54
  # Save the compressed DAC file to a temporary file
55
+ if temp_file_path:
56
+ os.remove(temp_file_path) # Remove the previous temporary file if it exists
57
+
58
  with tempfile.NamedTemporaryFile(delete=False, suffix=".dac") as temp_file:
59
  dac_file.save(temp_file.name)
60
+ temp_file_path = temp_file.name
61
 
62
+ return temp_file_path
63
 
64
  except Exception as e:
65
  gr.Warning(f"An error occurred during encoding: {e}")
 
75
  y = model.decompress(decompress_chunk, dac_file)
76
 
77
  # Convert to numpy array and squeeze to remove extra dimensions
78
+ decoded_audio = np.array(y).squeeze()
79
 
80
  return (44100, decoded_audio) # Return sample rate and audio data
81
 
 
83
  gr.Warning(f"An error occurred during decoding: {e}")
84
  return None
85
 
86
+ def cleanup(audio_file_path):
87
+ global temp_file_path
88
+ if temp_file_path and os.path.exists(temp_file_path):
89
+ os.remove(temp_file_path)
90
+ temp_file_path = None
91
+ return None
92
+
93
  # Gradio interface
94
  with gr.Blocks() as demo:
95
  gr.Markdown("<h1 style='text-align: center;'>Audio Compression with DAC-JAX</h1>")
 
102
  encoded_output = gr.File(label="Compressed Audio (.dac)")
103
 
104
  encode_button.click(encode, inputs=audio_input, outputs=encoded_output)
105
+ encoded_output.change(cleanup, inputs=[audio_input], outputs=None)
106
 
107
  with gr.Tab("Decode"):
108
  with gr.Row():