"""Gradio demo: compress audio to SemantiCodec tokens and decode it back.

Encoded files use a small ad-hoc ".owie" container:

    4 bytes  little-endian  sample rate of the original audio
    4 bytes  little-endian  number of dimensions of the token array
    4 bytes  little-endian  size of each dimension (repeated ndim times)
    rest                    LZ4-frame-compressed int64 token data
"""

import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import AsyncGenerator, Generator
import asyncio  # Used to yield control between streamed chunks
import traceback  # Used to log full tracebacks on decode failures

# Attempt to use GPU, fallback to CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

# Load the SemantiCodec model
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)

# Global variables for cancellation.
# NOTE(review): only `cancel_stream` is ever polled (inside the streaming
# loop); `cancel_encode` / `cancel_decode` are set by the UI and reset by the
# workers but never checked, so those Cancel buttons currently have no effect.
cancel_encode = False
cancel_decode = False
cancel_stream = False

# Width in bytes of every integer field in the .owie header.
_HEADER_FIELD_BYTES = 4
# Sample rate used only for the "silence" chunk emitted when streaming fails
# before the real sample rate could be read from the file.
_FALLBACK_SAMPLE_RATE = 16000


def _remove_quietly(path):
    """Delete *path* if it is set, ignoring filesystem errors (best effort)."""
    if path is None:
        return
    try:
        os.remove(path)
    except OSError:
        pass


def _read_u32(stream):
    """Read one little-endian header field; raise ValueError if truncated."""
    raw = stream.read(_HEADER_FIELD_BYTES)
    if len(raw) != _HEADER_FIELD_BYTES:
        raise ValueError("Truncated .owie header")
    return int.from_bytes(raw, byteorder='little')


def _write_owie(path, sample_rate, tokens_numpy):
    """Write the sample rate, token shape and LZ4-compressed tokens to *path*."""
    with open(path, 'wb') as out:
        out.write(int(sample_rate).to_bytes(_HEADER_FIELD_BYTES, byteorder='little'))
        out.write(len(tokens_numpy.shape).to_bytes(_HEADER_FIELD_BYTES, byteorder='little'))
        for dim in tokens_numpy.shape:
            out.write(int(dim).to_bytes(_HEADER_FIELD_BYTES, byteorder='little'))
        out.write(lz4.frame.compress(tokens_numpy.tobytes()))


def _read_owie(path):
    """Parse an .owie file and return ``(sample_rate, tokens_numpy)``.

    The returned array is a writable int64 copy: ``np.frombuffer`` yields a
    read-only view, which ``torch.from_numpy`` warns about.
    """
    with open(path, 'rb') as src:
        sample_rate = _read_u32(src)
        ndim = _read_u32(src)
        shape = tuple(_read_u32(src) for _ in range(ndim))
        compressed_data = src.read()
    tokens_bytes = lz4.frame.decompress(compressed_data)
    tokens_numpy = np.frombuffer(tokens_bytes, dtype=np.int64).reshape(shape).copy()
    return sample_rate, tokens_numpy


@spaces.GPU(duration=30)
def encode_audio(audio_file_path):
    """Encode an audio file into a temporary .owie file.

    Args:
        audio_file_path: Path to the input audio, or None.

    Returns:
        Path to the written .owie file, or None when no input was given or
        encoding failed (the error is printed).
    """
    global cancel_encode
    if audio_file_path is None:
        print("No audio file provided")
        return None

    temp_wav_file_path = None
    temp_file_path = None
    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Ensure waveform has the correct number of dimensions
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Re-save as WAV because the codec is given a file path to encode
        temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)

        # Encode the audio
        tokens = semanticodec.encode(temp_wav_file_path)

        # The readers hard-code int64, so pin the dtype before serialising.
        tokens_numpy = tokens.detach().cpu().numpy().astype(np.int64, copy=False)
        print(f"Tokens shape: {tokens_numpy.shape}")

        # Create temporary .owie file
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)
        _write_owie(temp_file_path, sample_rate, tokens_numpy)

        return temp_file_path
    except Exception as e:
        print(f"Encoding error: {e}")
        # Do not leave a partially written output file behind.
        _remove_quietly(temp_file_path)
        return None
    finally:
        cancel_encode = False
        _remove_quietly(temp_wav_file_path)


# Add this function to handle the output
def handle_encode_output(file_path):
    """Map an encode result to ``(file output, error-message Markdown)``."""
    if file_path is None:
        return None, gr.Markdown("Encoding failed. Please ensure you've uploaded an audio file and try again.", visible=True)
    return file_path, gr.Markdown(visible=False)


@spaces.GPU(duration=30)
def decode_audio(encoded_file_path):
    """Decode an .owie file into a temporary WAV file.

    Args:
        encoded_file_path: Path to the .owie file, or None.

    Returns:
        Path to the decoded WAV file.

    Raises:
        gr.Error: If no file was provided or decoding failed. The output
            component expects a file path, so an error string must not be
            returned in its place.
    """
    global cancel_decode
    if encoded_file_path is None:
        raise gr.Error("Please upload an encoded .owie file before decoding.")

    try:
        # Load encoded data and sample rate
        sample_rate, tokens_numpy = _read_owie(encoded_file_path)

        # Move the tensor to the same device as the model
        tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)
        print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        # mkstemp creates the file atomically; mktemp is deprecated and racy.
        temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        # NOTE(review): this writes the decoded audio at the *input* file's
        # sample rate; confirm the codec's output rate matches it.
        torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
        return temp_wav_path
    except Exception as e:
        print(f"Decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        raise gr.Error(f"Decoding failed: {e}") from e
    finally:
        cancel_decode = False


@spaces.GPU(duration=30)
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    """Decode an .owie file chunk by chunk for streaming playback.

    Args:
        encoded_file_path: Path to the .owie file.

    Yields:
        ``(sample_rate, audio_data)`` tuples, one per decoded chunk. If an
        error occurs, a single chunk of silence is yielded instead of raising.
    """
    global cancel_stream
    # Defaults so the error path below can always emit silence, even when the
    # failure happens before the file header has been read.
    sample_rate = _FALLBACK_SAMPLE_RATE
    chunk_size = _FALLBACK_SAMPLE_RATE // 2
    try:
        # Load encoded data and sample rate from the .owie file
        sample_rate, tokens_numpy = _read_owie(encoded_file_path)

        # Move the tensor to the same device as the model
        tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)
        print(f"Streaming tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")

        # Decode the audio in chunks. Clamp to 1 so a tiny/zero sample rate in
        # the header cannot produce a zero range step.
        # NOTE(review): the chunk size is derived from the sample rate but is
        # applied to the token axis; confirm this is the intended granularity.
        chunk_size = max(1, sample_rate // 2)
        with torch.no_grad():
            for start in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation is requested

                tokens_chunk = tokens[:, start:start + chunk_size, :]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Convert to numpy array and transpose
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Let the event loop run other tasks
    except Exception as e:
        print(f"Streaming decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))  # Return silence
    finally:
        cancel_stream = False


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_error_message = gr.Markdown(visible=False)

        def encode_wrapper(audio):
            """Validate the upload, then encode and format the UI outputs."""
            if audio is None:
                return None, gr.Markdown("Please upload an audio file before encoding.", visible=True)
            return handle_encode_output(encode_audio(audio))

        encode_button.click(
            encode_wrapper,
            inputs=input_audio,
            outputs=[encoded_output, encode_error_message]
        )
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True), outputs=None)

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(lambda: globals().update(cancel_decode=True), outputs=None)

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True), outputs=None)

demo.queue().launch()