|
import gradio as gr |
|
import torch |
|
import torchaudio |
|
from agc import AGC |
|
import tempfile |
|
import numpy as np |
|
import lz4.frame |
|
import os |
|
from typing import Generator |
|
import spaces |
|
|
|
|
|
# Choose the compute device once at import time, preferring CUDA when present.
try:
    if torch.cuda.is_available():
        torch_device = torch.device("cuda")
    else:
        torch_device = torch.device("cpu")
    print(f"Using device: {torch_device}")
except Exception as exc:
    # Probing the GPU failed: report why and carry on with the CPU.
    print(f"Error detecting GPU. Using CPU. Error: {exc}")
    torch_device = torch.device("cpu")
|
|
|
|
|
def load_agc_model():
    """Load the pretrained "Audiogen/agc-continuous" AGC model onto ``torch_device``."""
    model = AGC.from_pretrained("Audiogen/agc-continuous")
    return model.to(torch_device)


# Single model instance shared by the encode/decode functions in this file.
agc = load_agc_model()
|
|
|
@spaces.GPU(duration=180)
def encode_audio(audio_file_path):
    """Encode an audio file into an LZ4-compressed AGC latent file.

    The waveform is loaded with torchaudio, a single-channel input is
    duplicated to two channels, the result is passed through ``agc.encode``
    and the latent is stored as raw float32 bytes inside an LZ4 frame.

    Args:
        audio_file_path: Path of the audio file to encode.

    Returns:
        Path of the written ``.owie`` temp file, or a string starting with
        ``"Encoding error: "`` if any step failed.
        NOTE(review): the error string is returned through the same channel
        as the file path -- confirm the UI output handles that as intended.
    """
    try:
        # The file's sample rate is not used: it is neither resampled nor stored.
        waveform, _ = torchaudio.load(audio_file_path)

        # Duplicate a lone channel so the model always receives two.
        if waveform.size(0) == 1:
            waveform = waveform.repeat(2, 1)

        # Add a leading batch dimension and move to the model's device.
        audio = waveform.unsqueeze(0).to(torch_device)
        with torch.no_grad():
            z = agc.encode(audio)

        # The decoders in this file read the payload back as np.float32, so
        # pin the dtype here instead of relying on what the model emitted.
        z_numpy = z.detach().float().cpu().numpy()

        # Compress before creating the temp file so a failure up to this
        # point does not leave an orphan file behind.
        compressed_data = lz4.frame.compress(z_numpy.tobytes())

        # Write through the descriptor mkstemp already opened rather than
        # closing it and reopening the path.
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        with os.fdopen(temp_fd, "wb") as temp_file:
            temp_file.write(compressed_data)

        return temp_file_path

    except Exception as e:
        return f"Encoding error: {e}"
|
|
|
@spaces.GPU(duration=180)
def decode_audio(encoded_file_path, sample_rate=48000):
    """Decode an ``.owie`` latent file back into a WAV file.

    Args:
        encoded_file_path: Path of an LZ4-compressed float32 latent file.
        sample_rate: Sample rate passed to ``torchaudio.save``. The encoded
            file carries no rate information, so a default is required.
            NOTE(review): 48000 assumes the AGC model's native rate --
            confirm against the model documentation.

    Returns:
        Path of the written ``.wav`` temp file, or a string starting with
        ``"Decoding error: "`` if any step failed.
    """
    try:
        with open(encoded_file_path, "rb") as encoded_file:
            compressed_data = encoded_file.read()

        z_numpy_bytes = lz4.frame.decompress(compressed_data)

        # The payload is a flat float32 buffer; restore the (1, 32, N) layout.
        z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(1, 32, -1)

        # frombuffer returns a read-only view; copy so torch gets writable memory.
        z = torch.from_numpy(z_numpy.copy()).to(torch_device)

        with torch.no_grad():
            reconstructed_audio = agc.decode(z)

        # mkstemp creates the file itself; the deprecated mktemp only returns
        # a name that another process could claim first.
        temp_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_fd)

        # Drop the batch dimension before saving.
        torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), sample_rate)
        return temp_wav_path

    except Exception as e:
        return f"Decoding error: {e}"
|
|
|
@spaces.GPU(duration=180)
def stream_decode_audio(encoded_file_path) -> Generator[np.ndarray, None, None]:
    """Decode an ``.owie`` latent file slice by slice, yielding audio arrays.

    The latent is cut along its last axis into slices of ``chunk_size``
    frames and each slice is decoded independently with ``agc.decode``.

    Args:
        encoded_file_path: Path of an LZ4-compressed float32 latent file.

    Yields:
        One NumPy array per decoded slice, with the batch dimension removed.
        If any step fails, the error is printed and a single zero-filled
        array of shape ``(2, chunk_size)`` is yielded instead.
        NOTE(review): the Gradio Audio output wired to this generator may
        expect ``(sample_rate, array)`` tuples rather than bare arrays --
        confirm against the Gradio version in use.
    """
    # Bound before the try block so the except handler can always use it,
    # even when the failure happens while reading the file.
    chunk_size = 16000

    try:
        with open(encoded_file_path, "rb") as encoded_file:
            compressed_data = encoded_file.read()

        z_numpy_bytes = lz4.frame.decompress(compressed_data)

        # The payload is a flat float32 buffer; restore the (1, 32, N) layout.
        z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(1, 32, -1)
        z = torch.from_numpy(z_numpy).to(torch_device)

        for start in range(0, z.shape[2], chunk_size):
            z_chunk = z[:, :, start:start + chunk_size]
            # Keep no_grad scoped to the decode call: yielding from inside
            # the context would leave grad mode disabled for the consumer
            # between chunks.
            with torch.no_grad():
                audio_chunk = agc.decode(z_chunk)
            yield audio_chunk.squeeze(0).cpu().numpy()

    except Exception as e:
        # Log first: code after a yield only runs if the consumer resumes
        # the generator, so the message could otherwise be lost.
        print(f"Streaming decoding error: {e}")
        yield np.zeros((2, chunk_size))
|
|
|
|
|
# Gradio front end: one tab each for encoding, decoding and streamed decoding.
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with AGC (GPU/CPU)")

    # Audio file in, compressed latent file out.
    with gr.Tab("Encode"):
        input_audio = gr.Audio(type="filepath", label="Input Audio")
        encode_button = gr.Button("Encode")
        encoded_output = gr.File(type="filepath", label="Encoded File (.owie)")

        encode_button.click(
            fn=encode_audio,
            inputs=input_audio,
            outputs=encoded_output,
        )

    # Compressed latent file in, decoded audio file out.
    with gr.Tab("Decode"):
        input_encoded = gr.File(type="filepath", label="Encoded File (.owie)")
        decode_button = gr.Button("Decode")
        decoded_output = gr.Audio(type="filepath", label="Decoded Audio")

        decode_button.click(
            fn=decode_audio,
            inputs=input_encoded,
            outputs=decoded_output,
        )

    # Compressed latent file in, audio delivered chunk by chunk.
    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(type="filepath", label="Encoded File (.owie)")
        stream_button = gr.Button("Start Streaming")
        audio_output = gr.Audio(streaming=True, label="Streaming Audio Output")

        stream_button.click(
            fn=stream_decode_audio,
            inputs=input_encoded_stream,
            outputs=audio_output,
        )

# Enable the queue, then launch the app.
demo.queue()
demo.launch()