"""Gradio demo: compress audio into AGC latents (.owie files) and decode them back."""

import os
import tempfile
from typing import Generator, Tuple

import gradio as gr
import lz4.frame
import numpy as np
import spaces
import torch
import torchaudio
from agc import AGC

# Sample rate used for all audio entering or leaving the codec.
# NOTE(review): 48 kHz is taken from the AGC usage example (48 kHz stereo
# input) -- confirm against the Audiogen/agc-continuous model card.
AGC_SAMPLE_RATE = 48_000

# Channel count of the latent; decoding reshapes the stored floats to
# (1, LATENT_CHANNELS, frames), so this fixes the .owie layout.
LATENT_CHANNELS = 32

# Number of latent frames decoded per streamed chunk. This indexes the latent
# time axis, not audio samples, so one chunk covers many more samples than
# frames (the exact ratio depends on the codec's hop size).
STREAM_CHUNK_FRAMES = 16000

# Attempt to use GPU, fallback to CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")


# Load the AGC model
@spaces.GPU(duration=180)
def load_agc_model():
    """Load the pretrained continuous AGC model and move it to ``torch_device``."""
    return AGC.from_pretrained("Audiogen/agc-continuous").to(torch_device)


agc = load_agc_model()


def _new_temp_path(suffix: str) -> str:
    """Create an empty temporary file with *suffix* and return its path.

    ``mkstemp`` creates the file atomically, unlike the deprecated
    ``mktemp`` which only returns a name that another process could claim.
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)  # callers reopen the path themselves
    return path


def _read_latents(encoded_file_path: str) -> torch.Tensor:
    """Read an .owie file and return its latents as a tensor on ``torch_device``.

    The file is an LZ4 frame holding raw float32 values, reshaped here to
    ``(1, LATENT_CHANNELS, frames)``.
    """
    with open(encoded_file_path, "rb") as encoded_file:
        raw = lz4.frame.decompress(encoded_file.read())
    # np.frombuffer returns a read-only view; copy so torch gets writable memory.
    latents = np.frombuffer(raw, dtype=np.float32).reshape(1, LATENT_CHANNELS, -1).copy()
    return torch.from_numpy(latents).to(torch_device)


def _to_int16_frames(audio: torch.Tensor) -> np.ndarray:
    """Convert a decoded ``(1, channels, samples)`` waveform to int16 ``(samples, channels)``.

    Gradio expects channels on the last axis. Converting to int16 here keeps
    Gradio from doing its own float-to-int conversion per chunk.
    NOTE(review): assumes the decoder output is float audio in [-1, 1];
    values outside that range are clipped.
    """
    samples = audio.squeeze(0).detach().cpu().float().numpy().T
    return (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)


@spaces.GPU(duration=180)
def encode_audio(audio_file_path):
    """Encode an audio file with AGC and write the latents to a temporary .owie file.

    Args:
        audio_file_path: Path of the audio file to encode.

    Returns:
        Path of the LZ4-compressed .owie file holding the float32 latents.

    Raises:
        gr.Error: If no file was given or loading/encoding fails. Raising
            (rather than returning the message) stops Gradio from treating
            the error text as a file path.
    """
    if not audio_file_path:
        raise gr.Error("Encoding error: no audio file provided")
    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Bring the input to the codec's rate so decoded audio plays at the right speed
        if sample_rate != AGC_SAMPLE_RATE:
            waveform = torchaudio.functional.resample(waveform, sample_rate, AGC_SAMPLE_RATE)

        # Convert to stereo if necessary
        if waveform.size(0) == 1:
            waveform = waveform.repeat(2, 1)

        # Encode the audio
        audio = waveform.unsqueeze(0).to(torch_device)
        with torch.no_grad():
            z = agc.encode(audio)

        # Store as float32 because decoding reads the bytes back as float32
        z_numpy = z.detach().cpu().float().numpy()
        compressed_data = lz4.frame.compress(z_numpy.tobytes())

        # Create the file only once there is data to write, so failures above leave nothing behind
        temp_file_path = _new_temp_path(".owie")
        with open(temp_file_path, "wb") as temp_file:
            temp_file.write(compressed_data)
        return temp_file_path
    except Exception as e:
        raise gr.Error(f"Encoding error: {e}") from e


@spaces.GPU(duration=180)
def decode_audio(encoded_file_path):
    """Decode an .owie file back to audio and save it as a temporary WAV file.

    Args:
        encoded_file_path: Path of an .owie file produced by ``encode_audio``.

    Returns:
        Path of the reconstructed WAV file, written at ``AGC_SAMPLE_RATE``.

    Raises:
        gr.Error: If no file was given or reading/decoding fails.
    """
    if not encoded_file_path:
        raise gr.Error("Decoding error: no encoded file provided")
    try:
        # Load encoded data from the .owie file
        z = _read_latents(encoded_file_path)

        # Decode the audio
        with torch.no_grad():
            reconstructed_audio = agc.decode(z)

        # Save to a temporary WAV file
        temp_wav_path = _new_temp_path(".wav")
        torchaudio.save(
            temp_wav_path,
            reconstructed_audio.squeeze(0).detach().cpu().float(),
            AGC_SAMPLE_RATE,
        )
        return temp_wav_path
    except Exception as e:
        raise gr.Error(f"Decoding error: {e}") from e


@spaces.GPU(duration=180)
def stream_decode_audio(encoded_file_path) -> Generator[Tuple[int, np.ndarray], None, None]:
    """Decode an .owie file chunk by chunk, yielding audio for a streaming player.

    Args:
        encoded_file_path: Path of an .owie file produced by ``encode_audio``.

    Yields:
        ``(sample_rate, samples)`` tuples, where ``samples`` is an int16 array
        shaped ``(samples, channels)``. On failure the error is printed and
        one second of stereo silence is yielded instead.
    """
    try:
        # Load encoded data from the .owie file
        z = _read_latents(encoded_file_path)

        # Decode the audio in chunks of latent frames
        for start in range(0, z.shape[2], STREAM_CHUNK_FRAMES):
            z_chunk = z[:, :, start:start + STREAM_CHUNK_FRAMES]
            # Scope no_grad to the decode call so it is not held open across yields
            with torch.no_grad():
                audio_chunk = agc.decode(z_chunk)
            yield AGC_SAMPLE_RATE, _to_int16_frames(audio_chunk)
    except Exception as e:
        print(f"Streaming decoding error: {e}")
        # Return silence in case of error
        yield AGC_SAMPLE_RATE, np.zeros((AGC_SAMPLE_RATE, 2), dtype=np.int16)


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with AGC (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_button.click(encode_audio, inputs=input_audio, outputs=encoded_output)

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")
        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)
        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)

demo.queue().launch()