"""Gradio demo that compresses and reconstructs audio with the AGC codec.

Encoded files use the ad-hoc ``.owie`` container written by ``encode_audio``:

* bytes 0-3: the source sample rate as a little-endian unsigned integer;
* remainder: an LZ4 frame holding the raw float32 latent values, which the
  decoders reshape to ``(1, 32, -1)``.
"""
import os
import tempfile
from typing import Generator

# Third-party imports keep their original relative order on purpose.
import gradio as gr
import torch
import torchaudio
from agc import AGC
import numpy as np
import lz4.frame
import spaces

# Number of bytes used for the sample-rate header of a .owie file.
_HEADER_BYTES = 4
# Second dimension the decoders assume when reshaping the latent payload.
_LATENT_CHANNELS = 32

# Attempt to use GPU, fallback to CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")


def load_agc_model():
    """Load the pretrained AGC model and move it to ``torch_device``."""
    return AGC.from_pretrained("Audiogen/agc-continuous").to(torch_device)


agc = load_agc_model()


def _read_owie(encoded_file_path):
    """Parse a ``.owie`` file into ``(sample_rate, latents)``.

    Args:
        encoded_file_path: Path to a file produced by ``encode_audio``.

    Returns:
        A tuple of the stored sample rate (int) and a float32 tensor of shape
        ``(1, 32, frames)`` placed on ``torch_device``.

    Raises:
        ValueError: If the header is missing or stores a non-positive sample
            rate, or if the payload cannot be reshaped to the latent layout.
        OSError: If the file cannot be read.
    """
    with open(encoded_file_path, 'rb') as encoded_file:
        header = encoded_file.read(_HEADER_BYTES)
        compressed_data = encoded_file.read()

    if len(header) < _HEADER_BYTES:
        raise ValueError("file is too short to contain a sample-rate header")
    sample_rate = int.from_bytes(header, byteorder='little')
    if sample_rate <= 0:
        # A zero rate would later be used as a range() step and as the
        # playback rate, so reject it here with a clear message.
        raise ValueError(f"invalid sample rate in header: {sample_rate}")

    z_numpy_bytes = lz4.frame.decompress(compressed_data)
    z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(
        1, _LATENT_CHANNELS, -1
    )
    # np.frombuffer yields a read-only view; copy so torch.from_numpy receives
    # a writable array and does not emit its non-writable-array warning.
    z = torch.from_numpy(z_numpy.copy()).to(torch_device)
    return sample_rate, z


@spaces.GPU(duration=180)
def encode_audio(audio_file_path):
    """Encode an audio file with AGC and write the result to a ``.owie`` file.

    Args:
        audio_file_path: Path of the audio file to load with torchaudio.

    Returns:
        Path of a newly created temporary ``.owie`` file.

    Raises:
        gr.Error: If loading, encoding or writing fails. The message keeps the
            ``"Encoding error: ..."`` text so it is shown in the UI instead of
            being handed to the file output as if it were a path.
    """
    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Encode the audio; unsqueeze adds a leading batch dimension.
        audio = waveform.unsqueeze(0).to(torch_device)
        with torch.no_grad():
            z = agc.encode(audio)

        # The decoders always read the payload as float32, so store it as
        # float32 regardless of the dtype the model produced.
        z_numpy = z.detach().cpu().numpy().astype(np.float32, copy=False)

        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)  # Only the reserved path is needed; reopen by name.

        with open(temp_file_path, 'wb') as temp_file:
            # Store the sample rate as the first 4 bytes
            temp_file.write(sample_rate.to_bytes(_HEADER_BYTES, byteorder='little'))
            # Compress and write the encoded data
            temp_file.write(lz4.frame.compress(z_numpy.tobytes()))

        return temp_file_path
    except Exception as e:
        raise gr.Error(f"Encoding error: {e}") from e


@spaces.GPU(duration=180)
def decode_audio(encoded_file_path):
    """Decode a ``.owie`` file to a temporary WAV file.

    Args:
        encoded_file_path: Path to a file produced by ``encode_audio``.

    Returns:
        Path of a newly created temporary ``.wav`` file.

    Raises:
        gr.Error: If reading, decoding or saving fails. The message keeps the
            ``"Decoding error: ..."`` text.
    """
    try:
        sample_rate, z = _read_owie(encoded_file_path)

        # Decode the audio
        with torch.no_grad():
            reconstructed_audio = agc.decode(z)

        # mkstemp creates the file atomically; the deprecated mktemp only
        # returns a name that another process could claim first.
        temp_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_fd)
        torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), sample_rate)

        return temp_wav_path
    except Exception as e:
        raise gr.Error(f"Decoding error: {e}") from e


@spaces.GPU(duration=180)
def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]:
    """Decode a ``.owie`` file chunk by chunk for streaming playback.

    Args:
        encoded_file_path: Path to a file produced by ``encode_audio``.

    Yields:
        ``(sample_rate, audio)`` tuples, where ``audio`` is the decoded chunk
        as a NumPy array transposed after dropping the batch dimension.

    Raises:
        gr.Error: If the file cannot be parsed at all, so no sample rate is
            available to build a fallback chunk. Failures after that point
            are logged and followed by one chunk of silence instead.
    """
    # Initialised before the try block so the except handler never touches an
    # unbound name when the failure happens during file parsing.
    sample_rate = None
    num_channels = 1
    try:
        sample_rate, z = _read_owie(encoded_file_path)

        # Decode the audio in chunks
        chunk_size = sample_rate  # Use the stored sample rate as chunk size
        with torch.no_grad():
            for start in range(0, z.shape[2], chunk_size):
                z_chunk = z[:, :, start:start + chunk_size]
                audio_chunk = agc.decode(z_chunk)
                # Convert to numpy array and transpose
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                if audio_data.ndim == 2:
                    # Remember the channel count so fallback silence matches
                    # the chunks that were already streamed.
                    num_channels = audio_data.shape[1]
                yield (sample_rate, audio_data)
    except Exception as e:
        print(f"Streaming decoding error: {e}")
        if sample_rate is None:
            raise gr.Error(f"Streaming decoding error: {e}") from e
        # Return silence shaped like audio (samples, channels), not like the
        # 32-channel latent layout.
        yield (sample_rate, np.zeros((sample_rate, num_channels), dtype=np.float32))


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with AGC (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")

        encode_button.click(encode_audio, inputs=input_audio, outputs=encoded_output)

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)

demo.queue().launch()