Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,16 @@ from dac_jax.audio_utils import volume_norm, db2linear
|
|
7 |
import io
|
8 |
import soundfile as sf
|
9 |
import spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Load the DAC model with padding set to False for chunking
|
12 |
model, variables = dac_jax.load_model(model_type="44khz", padding=False)
|
@@ -23,15 +33,20 @@ def decompress_chunk(c):
|
|
23 |
@spaces.GPU
|
24 |
def encode(audio_file):
|
25 |
try:
|
26 |
-
#
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
signal = jnp.array(signal, dtype=jnp.float32)
|
30 |
while signal.ndim < 3:
|
31 |
signal = jnp.expand_dims(signal, axis=0)
|
32 |
|
33 |
# Set chunk duration based on available GPU memory (adjust as needed)
|
34 |
-
win_duration = 5.0
|
35 |
|
36 |
# Compress using chunking
|
37 |
dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)
|
@@ -46,6 +61,9 @@ def encode(audio_file):
|
|
46 |
except Exception as e:
|
47 |
gr.Warning(f"An error occurred during encoding: {e}")
|
48 |
return None
|
|
|
|
|
|
|
49 |
|
50 |
@spaces.GPU
|
51 |
def decode(compressed_dac_file):
|
|
|
7 |
import io
|
8 |
import soundfile as sf
|
9 |
import spaces
|
10 |
+
import tempfile
|
11 |
+
import os
|
12 |
+
|
13 |
+
# Try to connect to a Colab TPU; if that fails, fall through to JAX's default GPU/CPU backend
|
14 |
+
try:
|
15 |
+
import jax.tools.colab_tpu
|
16 |
+
jax.tools.colab_tpu.setup_tpu()
|
17 |
+
print("Connected to TPU")
|
18 |
+
except:
|
19 |
+
print("No TPU detected, using GPU or CPU.")
|
20 |
|
21 |
# Load the DAC model with padding set to False for chunking
|
22 |
model, variables = dac_jax.load_model(model_type="44khz", padding=False)
|
|
|
33 |
@spaces.GPU
|
34 |
def encode(audio_file):
|
35 |
try:
|
36 |
+
# Save the uploaded audio to a temporary file
|
37 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio_file:
|
38 |
+
temp_audio_file.write(audio_file.read())
|
39 |
+
temp_audio_file_path = temp_audio_file.name
|
40 |
+
|
41 |
+
# Load the temporary file as mono audio at a 44.1 kHz sample rate
|
42 |
+
signal, sample_rate = librosa.load(temp_audio_file_path, sr=44100, mono=True)
|
43 |
|
44 |
signal = jnp.array(signal, dtype=jnp.float32)
|
45 |
while signal.ndim < 3:
|
46 |
signal = jnp.expand_dims(signal, axis=0)
|
47 |
|
48 |
# Set chunk duration based on available GPU memory (adjust as needed)
|
49 |
+
win_duration = 5.0
|
50 |
|
51 |
# Compress using chunking
|
52 |
dac_file = model.compress(compress_chunk, signal, sample_rate, win_duration=win_duration)
|
|
|
61 |
except Exception as e:
|
62 |
gr.Warning(f"An error occurred during encoding: {e}")
|
63 |
return None
|
64 |
+
finally:
|
65 |
+
# Clean up the temporary file
|
66 |
+
os.remove(temp_audio_file_path)
|
67 |
|
68 |
@spaces.GPU
|
69 |
def decode(compressed_dac_file):
|