import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import AsyncGenerator
import asyncio  # Used to yield control so cancellation requests can be processed

# Attempt to use GPU, fall back to CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

# Load the SemantiCodec model
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)

# Global flags for cooperative cancellation
cancel_encode = False
cancel_decode = False
cancel_stream = False


@spaces.GPU(duration=250)
def encode_audio(audio_file_path):
    global cancel_encode

    # Load the audio file
    waveform, sample_rate = torchaudio.load(audio_file_path)

    # Ensure the waveform has a channel dimension
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    # Write the audio to a temporary WAV file for the codec to read
    temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
    os.close(temp_wav_fd)
    torchaudio.save(temp_wav_file_path, waveform, sample_rate)

    # Encode the audio using the WAV file path
    tokens = semanticodec.encode(temp_wav_file_path)

    # Best-effort cancellation: bail out before writing the output file
    if cancel_encode:
        cancel_encode = False  # Reset the flag
        return None

    # Convert to NumPy and save to a temporary .owie file
    tokens_numpy = tokens.detach().cpu().numpy()
    temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
    os.close(temp_fd)
    with open(temp_file_path, 'wb') as temp_file:
        # Store the sample rate as the first 4 bytes (little-endian)
        temp_file.write(int(sample_rate).to_bytes(4, byteorder='little'))
        # Compress and write the encoded token data
        compressed_data = lz4.frame.compress(tokens_numpy.tobytes())
        temp_file.write(compressed_data)

    return temp_file_path


@spaces.GPU(duration=250)
def decode_audio(encoded_file_path):
    global cancel_decode

    # Load the sample rate and encoded data from the .owie file
    with open(encoded_file_path, 'rb') as temp_file:
        sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
        compressed_data = temp_file.read()
    tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
    tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)  # dtype must match the encoder's output

    # Reshape 1D token arrays to [batch_size, token_length, 1]
    if tokens_numpy.ndim == 1:
        tokens_numpy = tokens_numpy.reshape(1, -1, 1)
    tokens = torch.from_numpy(tokens_numpy).to(torch_device)

    # Best-effort cancellation: bail out before the heavy decode step
    if cancel_decode:
        cancel_decode = False  # Reset the flag
        return None

    # Debugging print to check the tensor shape
    print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")

    # Decode the audio
    with torch.no_grad():
        waveform = semanticodec.decode(tokens)

    # Save to a temporary WAV file (mkstemp avoids the race condition of the deprecated tempfile.mktemp)
    temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(temp_wav_fd)
    torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
    return temp_wav_path


@spaces.GPU(duration=250)
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    global cancel_stream

    # Initialize so the exception handler below never hits unbound names
    sample_rate = None
    chunk_size = None

    try:
        # Load the sample rate and encoded data from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)  # dtype must match the encoder's output

        # Reshape 1D token arrays to [batch_size, token_length, 1]
        if tokens_numpy.ndim == 1:
            tokens_numpy = tokens_numpy.reshape(1, -1, 1)
        tokens = torch.from_numpy(tokens_numpy).to(torch_device)

        # Ensure tokens has a batch dimension
        if tokens.ndimension() == 2:
            tokens = tokens.unsqueeze(0)

        # Decode the audio in chunks
        chunk_size = sample_rate  # Use the stored sample rate as the chunk size
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation was requested
                tokens_chunk = tokens[:, i:i + chunk_size]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Convert to a numpy array of shape [samples, channels]
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Yield control so a cancellation request can be processed

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        if sample_rate is not None and chunk_size is not None:
            yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))  # Return silence

    finally:
        cancel_stream = False  # Reset the cancel flag after streaming


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_button.click(encode_audio, inputs=input_audio, outputs=encoded_output)
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True), outputs=None)  # Set the cancel_encode flag

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")
        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(lambda: globals().update(cancel_decode=True), outputs=None)  # Set the cancel_decode flag

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)
        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True), outputs=None)  # Set the cancel_stream flag

demo.queue().launch()
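# ---------------------------------------------------------------------------
# Offline inspection sketch (not wired into the app above; "example.owie" is a
# hypothetical filename). The .owie container written by encode_audio() is
# just a 4-byte little-endian sample rate followed by the LZ4-compressed
# int64 token bytes, so it can be unpacked without loading the model:
#
#   with open("example.owie", "rb") as f:
#       sr = int.from_bytes(f.read(4), byteorder="little")
#       tokens = np.frombuffer(lz4.frame.decompress(f.read()), dtype=np.int64)
#       print(f"sample rate: {sr}, token count: {tokens.size}")
# ---------------------------------------------------------------------------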