# dac / app.py
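"""Gradio app for audio compression with the Audiogen AGC codec.

Audio is encoded into a continuous latent, lz4-compressed, and saved as a
.owie file; decoding reconstructs a 32 kHz WAV, either in one pass or as a
stream of chunks.
"""
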
import gradio as gr
import torch
import torchaudio
from agc import AGC
import tempfile
import numpy as np
import lz4.frame
import os
from typing import Generator
import spaces

# Attempt to use the GPU, fall back to CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")
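
# Note: the spaces.GPU decorator used below requests a ZeroGPU worker for each
# decorated call; `duration` is the expected upper bound, in seconds, on how long
# a single call may hold the GPU.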

# Load the AGC model
@spaces.GPU(duration=180)
def load_agc_model():
    return AGC.from_pretrained("Audiogen/agc-continuous").to(torch_device)

agc = load_agc_model()

@spaces.GPU(duration=180)
def encode_audio(audio_file_path):
    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Resample to 32kHz if necessary
        if sample_rate != 32000:
            resampler = torchaudio.transforms.Resample(sample_rate, 32000)
            waveform = resampler(waveform)

        # Pad (by repeating) or truncate to exactly 32 channels
        if waveform.size(0) < 32:
            waveform = waveform.repeat(32, 1)[:32, :]
        elif waveform.size(0) > 32:
            waveform = waveform[:32, :]

        # Encode the audio
        audio = waveform.unsqueeze(0).to(torch_device)
        with torch.no_grad():
            z = agc.encode(audio)

        # Convert the latent to float32 NumPy (decoding reads it back as float32)
        # and write it, lz4-compressed, to a temporary .owie file
        z_numpy = z.detach().cpu().numpy().astype(np.float32)
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)  # close the low-level descriptor; the file is reopened by path below
        with open(temp_file_path, 'wb') as temp_file:
            compressed_data = lz4.frame.compress(z_numpy.tobytes())
            temp_file.write(compressed_data)

        return temp_file_path
    except Exception as e:
        return f"Encoding error: {e}"
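
# Note: the .owie file written above is nothing more than the lz4-compressed raw
# bytes of a float32 latent with shape (1, 32, T). Illustrative sketch (not used by
# the UI, filename is only an example) of how such a file could be inspected offline:
#
#   with open("clip.owie", "rb") as f:
#       flat = np.frombuffer(lz4.frame.decompress(f.read()), dtype=np.float32)
#   print("latent frames:", flat.size // 32)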

@spaces.GPU(duration=180)
def decode_audio(encoded_file_path):
    try:
        # Load the encoded latent from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            compressed_data = temp_file.read()
        z_numpy_bytes = lz4.frame.decompress(compressed_data)
        # copy() because frombuffer returns a read-only view, which torch.from_numpy warns about
        z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).copy().reshape(1, 32, -1)
        z = torch.from_numpy(z_numpy).to(torch_device)

        # Decode the audio
        with torch.no_grad():
            reconstructed_audio = agc.decode(z)

        # Save to a temporary WAV file (mkstemp avoids the race condition of the
        # deprecated tempfile.mktemp)
        temp_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_fd)
        torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), 32000)

        return temp_wav_path
    except Exception as e:
        return f"Decoding error: {e}"

@spaces.GPU(duration=180)
def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]:
    chunk_size = 32000   # latent frames decoded per chunk
    sample_rate = 32000  # AGC model's output sample rate
    try:
        # Load the encoded latent from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            compressed_data = temp_file.read()
        z_numpy_bytes = lz4.frame.decompress(compressed_data)
        # copy() because frombuffer returns a read-only view, which torch.from_numpy warns about
        z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).copy().reshape(1, 32, -1)
        z = torch.from_numpy(z_numpy).to(torch_device)

        # Decode the audio chunk by chunk
        with torch.no_grad():
            for i in range(0, z.shape[2], chunk_size):
                z_chunk = z[:, :, i:i + chunk_size]
                audio_chunk = agc.decode(z_chunk)
                # Gradio expects (samples, channels), hence the transpose
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
    except Exception as e:
        print(f"Streaming decoding error: {e}")
        # Yield one second of silence so the stream fails gracefully
        yield (sample_rate, np.zeros((sample_rate, 1), dtype=np.float32))
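
# Because stream_decode_audio is a generator that yields (sample_rate, ndarray) tuples,
# Gradio plays each decoded chunk as it arrives when the output Audio component is
# created with streaming=True (see the Streaming tab below).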

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with AGC (GPU/CPU) - 32 channels, 32kHz")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_button.click(encode_audio, inputs=input_audio, outputs=encoded_output)

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")
        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)
        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)

demo.queue().launch()
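
# Run locally with `python app.py`; on a Hugging Face Space this file is the entry
# point and launch() serves the UI automatically.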