dac / app.py
owiedotch's picture
Update app.py
eb0f782 verified
raw
history blame
No virus
4.49 kB
import gradio as gr
import torch
import torchaudio
from agc import AGC
import tempfile
import numpy as np
import lz4.frame
import os
from typing import Generator
import spaces
# Choose the compute device once at import time: CUDA whenever PyTorch reports
# it as available, otherwise the CPU. If probing CUDA itself raises, report the
# error and fall back to the CPU so the app can still start.
try:
    if torch.cuda.is_available():
        torch_device = torch.device("cuda")
    else:
        torch_device = torch.device("cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")
# Load the AGC model once at import time; every request reuses the module-level `agc`.
@spaces.GPU(duration=180)
def load_agc_model():
    """Load the pretrained "Audiogen/agc-continuous" model onto ``torch_device``.

    The model is switched to inference mode with ``eval()`` because this app
    only ever runs it for inference (under ``torch.no_grad()``); this disables
    any train-time-only behaviour the architecture may contain.

    NOTE(review): assumes ``AGC`` is a ``torch.nn.Module`` (it exposes ``.to``
    and is used like one) so ``.eval()`` is available -- confirm.

    Returns:
        The loaded AGC model, on ``torch_device``, in eval mode.
    """
    model = AGC.from_pretrained("Audiogen/agc-continuous")
    # nn.Module.to() and .eval() both return the module itself.
    return model.to(torch_device).eval()


agc = load_agc_model()
@spaces.GPU(duration=180)
def encode_audio(audio_file_path, target_sample_rate=48000):
    """Encode an audio file into a compressed ``.owie`` latent file.

    The ``.owie`` format written here is an LZ4 frame wrapping the raw
    float32 bytes of the AGC latent tensor, with no header. The decoders in
    this app read it back with ``np.frombuffer(..., dtype=np.float32)``, so
    the latents are forced to float32 before serialisation.

    Args:
        audio_file_path: Path to the input audio file (from ``gr.Audio``
            with ``type="filepath"``).
        target_sample_rate: Rate the waveform is resampled to before encoding.
            The default of 48000 assumes AGC is a 48 kHz codec -- TODO confirm
            against the model configuration. Decoding writes audio at this rate.

    Returns:
        Path to the temporary ``.owie`` file.

    Raises:
        gr.Error: If no file is given or if loading/encoding/writing fails.
            The message is shown in the Gradio UI.
    """
    if not audio_file_path:
        raise gr.Error("Please provide an input audio file.")
    try:
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # The model is fed stereo: duplicate mono, and keep only the first two
        # channels of multichannel input (simple truncation, not a downmix).
        if waveform.size(0) == 1:
            waveform = waveform.repeat(2, 1)
        elif waveform.size(0) > 2:
            waveform = waveform[:2]

        # Bring the input to the codec's rate so that decoding at
        # `target_sample_rate` reproduces the original playback speed.
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, target_sample_rate
            )

        # Add a batch dimension: (1, channels, samples).
        audio = waveform.unsqueeze(0).to(torch_device)
        with torch.no_grad():
            z = agc.encode(audio)

        # Pin the dtype so the bytes match what the decoders expect.
        z_numpy = z.detach().cpu().numpy().astype(np.float32, copy=False)
        compressed_data = lz4.frame.compress(z_numpy.tobytes())

        # mkstemp creates the file securely; wrap its descriptor directly.
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        with os.fdopen(temp_fd, "wb") as temp_file:
            temp_file.write(compressed_data)
        return temp_file_path
    except Exception as e:
        # Returning the message as a "file path" would make Gradio fail to
        # serve it; gr.Error surfaces it to the user instead.
        raise gr.Error(f"Encoding error: {e}") from e
@spaces.GPU(duration=180)
def decode_audio(encoded_file_path, sample_rate=48000):
    """Decode a ``.owie`` latent file back into a temporary WAV file.

    The input is expected to be an LZ4 frame of raw float32 latents, as
    written by ``encode_audio``. The latents are reshaped to
    ``(1, 32, frames)`` before decoding.

    Args:
        encoded_file_path: Path to the ``.owie`` file (from ``gr.File``).
        sample_rate: Sample rate written into the WAV header. The previous
            code referenced an undefined ``sample_rate`` and always failed.
            The default of 48000 assumes AGC is a 48 kHz codec -- TODO confirm
            against the model configuration.

    Returns:
        Path to the temporary WAV file.

    Raises:
        gr.Error: If no file is given or if reading/decoding/saving fails.
            The message is shown in the Gradio UI.
    """
    if not encoded_file_path:
        raise gr.Error("Please provide an encoded .owie file.")
    try:
        with open(encoded_file_path, "rb") as encoded_file:
            compressed_data = encoded_file.read()
        z_numpy_bytes = lz4.frame.decompress(compressed_data)
        z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(1, 32, -1)
        # frombuffer yields a read-only array; copy it so torch.from_numpy
        # gets a writable buffer (avoids the non-writable-array warning).
        z = torch.from_numpy(z_numpy.copy()).to(torch_device)

        with torch.no_grad():
            reconstructed_audio = agc.decode(z)

        # mkstemp creates the file atomically; tempfile.mktemp is deprecated
        # and racy. Close the descriptor so torchaudio can reopen the path.
        temp_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_fd)
        # squeeze(0) drops the batch dim -> (channels, samples), torchaudio's
        # default channels_first layout.
        torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), sample_rate)
        return temp_wav_path
    except Exception as e:
        # Returning the message as an audio "file path" would make Gradio fail
        # to load it; gr.Error surfaces it to the user instead.
        raise gr.Error(f"Decoding error: {e}") from e
@spaces.GPU(duration=180)
def stream_decode_audio(
    encoded_file_path, chunk_size=16000, sample_rate=48000
) -> Generator[tuple, None, None]:
    """Decode a ``.owie`` latent file piece by piece for ``gr.Audio`` streaming.

    Args:
        encoded_file_path: Path to an LZ4-compressed float32 latent file, as
            written by ``encode_audio``; reshaped to ``(1, 32, frames)``.
        chunk_size: Number of *latent frames* decoded per step. The slice runs
            over the latent time axis, not over audio samples. How much audio
            one frame covers depends on the model's hop length (TODO confirm).
            Chunks are decoded independently, so boundaries may be audible.
        sample_rate: Rate reported to Gradio for every chunk. The default of
            48000 assumes AGC is a 48 kHz codec -- TODO confirm.

    Yields:
        ``(sample_rate, audio)`` tuples, where ``audio`` is a
        ``(samples, channels)`` array. This is the numpy layout ``gr.Audio``
        accepts. On any error, the error is printed and one second of stereo
        silence is yielded as a best-effort fallback.
    """
    try:
        with open(encoded_file_path, "rb") as encoded_file:
            compressed_data = encoded_file.read()
        z_numpy_bytes = lz4.frame.decompress(compressed_data)
        z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(1, 32, -1)
        # Copy: frombuffer arrays are read-only, which torch.from_numpy warns about.
        z = torch.from_numpy(z_numpy.copy()).to(torch_device)

        with torch.no_grad():
            for start in range(0, z.shape[2], chunk_size):
                audio_chunk = agc.decode(z[:, :, start:start + chunk_size])
                # (1, channels, samples) -> (samples, channels) for Gradio.
                yield sample_rate, audio_chunk.squeeze(0).cpu().numpy().T
    except Exception as e:
        # Report first: the consumer may stop iterating after the yield below.
        # The handler uses only parameters, so it works even if the failure
        # happened before any local was bound.
        print(f"Streaming decoding error: {e}")
        yield sample_rate, np.zeros((sample_rate, 2), dtype=np.float32)
# Gradio UI: one tab per operation -- encode to .owie, decode to WAV, and
# streaming decode. The app is launched with request queuing enabled, which
# the streaming generator output needs.
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with AGC (GPU/CPU)")

    # Audio file in -> compressed latent (.owie) file out.
    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_button.click(
            fn=encode_audio,
            inputs=input_audio,
            outputs=encoded_output,
        )

    # Latent (.owie) file in -> reconstructed audio file out.
    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")
        decode_button.click(
            fn=decode_audio,
            inputs=input_encoded,
            outputs=decoded_output,
        )

    # Latent (.owie) file in -> audio streamed chunk by chunk from a generator.
    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)
        stream_button.click(
            fn=stream_decode_audio,
            inputs=input_encoded_stream,
            outputs=audio_output,
        )

# Blocks.queue() returns the same Blocks object, so these two calls are
# equivalent to the chained form.
demo.queue()
demo.launch()