import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import AsyncGenerator
import asyncio  # Used to yield control in the streaming loop so cancellation can be observed
import traceback  # Used for detailed error reporting

# Attempt to use GPU, fall back to CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

# Load the SemantiCodec model
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)

# Global flags used for cooperative cancellation
cancel_encode = False
cancel_decode = False
cancel_stream = False


@spaces.GPU(duration=30)
def encode_audio(audio_file_path):
    global cancel_encode

    if audio_file_path is None:
        print("No audio file provided")
        return None

    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Ensure the waveform has a channel dimension
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Save to a temporary WAV file
        temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)

        # Encode the audio
        tokens = semanticodec.encode(temp_wav_file_path)

        # Convert tokens to NumPy
        tokens_numpy = tokens.detach().cpu().numpy()
        print(f"Original tokens shape: {tokens_numpy.shape}")

        # Normalize the token array to 2D before serializing
        if tokens_numpy.ndim == 1:
            tokens_numpy = tokens_numpy.reshape(1, -1)
        elif tokens_numpy.ndim == 2:
            pass  # Already 2D
        elif tokens_numpy.ndim == 3 and tokens_numpy.shape[0] == 1:
            tokens_numpy = tokens_numpy.squeeze(0)
        else:
            raise ValueError(f"Unexpected tokens array shape: {tokens_numpy.shape}")

        print(f"Reshaped tokens shape: {tokens_numpy.shape}")

        # Create a temporary .owie file: a 4-byte little-endian sample rate,
        # followed by the LZ4-compressed token bytes
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)
        with open(temp_file_path, 'wb') as temp_file:
            temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
            compressed_data = lz4.frame.compress(tokens_numpy.tobytes())
            temp_file.write(compressed_data)

        return temp_file_path

    except Exception as e:
        print(f"Encoding error: {e}")
        return None  # Return None instead of the error message so the UI can react

    finally:
        cancel_encode = False  # Reset the cancel flag after encoding
        if 'temp_wav_file_path' in locals():
            os.remove(temp_wav_file_path)  # Clean up the temporary WAV file
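
# Minimal standalone sketch of the .owie container written above (illustrative
# only; `read_owie` is a hypothetical helper introduced here and is not used by
# the app): the file is a 4-byte little-endian sample rate followed by an
# LZ4-frame-compressed dump of the int64 token array.
def read_owie(path):
    with open(path, 'rb') as f:
        sample_rate = int.from_bytes(f.read(4), byteorder='little')
        tokens = np.frombuffer(lz4.frame.decompress(f.read()), dtype=np.int64)
    # Restore the (batch, length, 2) shape that the decoder functions below expect
    return sample_rate, tokens.reshape(1, -1, 2)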

# Map the encoder result onto the UI outputs
def handle_encode_output(file_path):
    if file_path is None:
        return None, gr.Markdown(
            "Encoding failed. Please ensure you've uploaded an audio file and try again.",
            visible=True,
        )
    return file_path, gr.Markdown(visible=False)


@spaces.GPU(duration=30)
def decode_audio(encoded_file_path):
    global cancel_decode

    if encoded_file_path is None:
        print("No encoded file provided")
        return None

    try:
        # Load the sample rate and compressed token data from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)
            # Restore the original (batch, length, 2) token shape
            tokens_numpy = tokens_numpy.reshape(1, -1, 2)
            tokens = torch.from_numpy(tokens_numpy).to(torch_device)

        # Debugging print to check the tensor shape
        print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        # Save to a temporary WAV file
        temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
        return temp_wav_path

    except Exception as e:
        print(f"Decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        return None  # Return None instead of an error string, which gr.Audio cannot render

    finally:
        cancel_decode = False  # Reset the cancel flag after decoding


@spaces.GPU(duration=30)
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    global cancel_stream

    # Initialized up front so the error path below cannot hit unbound locals
    sample_rate = None
    chunk_size = None

    try:
        # Load the sample rate and compressed token data from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)
            tokens_numpy = tokens_numpy.reshape(1, -1, 2)
            tokens = torch.from_numpy(tokens_numpy).to(torch_device)

        print(f"Streaming tokens shape: {tokens.shape}, dtype: {tokens.dtype}")

        # Decode the audio in chunks along the token-length dimension
        chunk_size = sample_rate // 2  # Adjust chunk size to account for the new shape
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation was requested

                tokens_chunk = tokens[:, i:i + chunk_size, :]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Convert to a NumPy array shaped (samples, channels)
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Yield control so the cancel flag can be observed

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        if sample_rate is not None and chunk_size is not None:
            yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))  # Yield silence on error

    finally:
        cancel_stream = False  # Reset the cancel flag after streaming


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_error_message = gr.Markdown(visible=False)

        def encode_wrapper(audio):
            if audio is None:
                return None, gr.Markdown("Please upload an audio file before encoding.", visible=True)
            return handle_encode_output(encode_audio(audio))

        encode_button.click(
            encode_wrapper,
            inputs=input_audio,
            outputs=[encoded_output, encode_error_message]
        )
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True), outputs=None)  # Set the cancel_encode flag
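
    # Note on cancellation (as implemented above): the Cancel buttons flip
    # module-level flags via globals().update(...). encode_audio() and
    # decode_audio() only reset their flags in `finally` and never check them
    # mid-run, so in practice only the streaming loop, which tests
    # cancel_stream on every chunk, can actually be interrupted.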
gr.Tab("Decode"): input_encoded = gr.File(label="Encoded File (.owie)", type="filepath") # Using "filepath" mode decode_button = gr.Button("Decode") cancel_decode_button = gr.Button("Cancel") decoded_output = gr.Audio(label="Decoded Audio", type="filepath") # Using "filepath" mode decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output) cancel_decode_button.click(lambda: globals().update(cancel_decode=True), outputs=None) # Set cancel_decode flag with gr.Tab("Streaming"): input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath") # Using "filepath" mode stream_button = gr.Button("Start Streaming") cancel_stream_button = gr.Button("Cancel") audio_output = gr.Audio(label="Streaming Audio Output", streaming=True) stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output) cancel_stream_button.click(lambda: globals().update(cancel_stream=True), outputs=None) # Set cancel_stream flag demo.queue().launch()