File size: 7,640 Bytes
f18f98b 6fa15d7 eb0f782 763a29b a2cc897 bd40662 eb0f782 763a29b c707064 bd40662 eb0f782 c27eb74 763a29b a7dbbfe 50a007c 763a29b c9a89ac cefd33c eb0f782 763a29b 841c4fd 306e4c8 841c4fd cefd33c 841c4fd cefd33c 841c4fd cefd33c 841c4fd cefd33c 73fdeaf cefd33c 73fdeaf 306e4c8 73fdeaf 763a29b cefd33c eb0f782 763a29b 4ca3581 841c4fd cefd33c 841c4fd ff92e4e d87908f 841c4fd 4ca3581 841c4fd cefd33c 841c4fd 4ca3581 841c4fd c707064 cefd33c 841c4fd cefd33c 763a29b cefd33c 763a29b 841c4fd eb0f782 d24bcbd eb0f782 d24bcbd eb0f782 763a29b cefd33c ff92e4e c707064 eb0f782 cefd33c eb0f782 cefd33c 763a29b cefd33c 763a29b d888fa7 763a29b c27eb74 eb0f782 c707064 763a29b c27eb74 763a29b cefd33c d24bcbd eb0f782 d07be48 763a29b 44bab11 f18f98b cefd33c eb0f782 763a29b cefd33c 73fdeaf 44bab11 306e4c8 73fdeaf 306e4c8 73fdeaf cefd33c 44bab11 f18f98b cefd33c eb0f782 763a29b cefd33c eb0f782 cefd33c eb0f782 cefd33c eb0f782 763a29b d888fa7 44bab11 eb0f782 cefd33c f18f98b 73fdeaf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import asyncio  # Import asyncio for cancellation
import os
import tempfile
import traceback  # Import traceback for error handling
from typing import AsyncGenerator, Generator

# NOTE(review): third-party imports keep their original relative order in case
# `spaces` has to be imported ahead of the torch stack -- confirm before sorting.
import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import numpy as np
import lz4.frame
# Select the compute device: CUDA when torch reports it, otherwise CPU.
# Any exception raised while probing also falls back to CPU.
try:
    if torch.cuda.is_available():
        torch_device = torch.device("cuda")
    else:
        torch_device = torch.device("cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")
# Build the SemantiCodec model, then move it onto the selected device
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768)
semanticodec = semanticodec.to(torch_device)
# Module-level cancellation flags, one per operation (all start cleared)
cancel_encode = cancel_decode = cancel_stream = False
@spaces.GPU(duration=30)
def encode_audio(audio_file_path):
    """Encode an audio file into a compressed ``.owie`` token file.

    The input is re-saved as a temporary WAV, tokenized with the module-level
    ``semanticodec`` model, and written out as a 4-byte little-endian sample
    rate followed by the LZ4-frame-compressed int64 token bytes.

    Args:
        audio_file_path: Path of the audio file to encode, or ``None``.

    Returns:
        Path of the generated ``.owie`` file, or ``None`` when no input was
        given or encoding failed (the error is printed, not raised).
    """
    global cancel_encode
    if audio_file_path is None:
        print("No audio file provided")
        return None

    # Initialised up front so the cleanup paths can tell what was created.
    temp_wav_file_path = None
    temp_file_path = None
    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)
        # torchaudio.save needs a (channels, samples) tensor
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # The model's encode() is given a file path, so round-trip via a WAV
        temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)

        # Encode the audio
        tokens = semanticodec.encode(temp_wav_file_path)

        # The decoders read the payload back as int64, so pin the dtype here
        # (a no-op when the tokens are already int64).
        tokens_numpy = tokens.detach().cpu().numpy().astype(np.int64, copy=False)
        print(f"Tokens shape: {tokens_numpy.shape}")

        # Create temporary .owie file
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)
        with open(temp_file_path, 'wb') as temp_file:
            # Header: sample rate of the *input* file.
            # NOTE(review): confirm this matches the rate of the audio that
            # semanticodec.decode() produces; the decoders reuse this value.
            temp_file.write(int(sample_rate).to_bytes(4, byteorder='little'))
            # Payload: compressed raw token bytes
            temp_file.write(lz4.frame.compress(tokens_numpy.tobytes()))
        return temp_file_path
    except Exception as e:
        print(f"Encoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        # Do not leave a half-written .owie file behind.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
        return None
    finally:
        cancel_encode = False
        # Best-effort cleanup; a failure here must not mask the result.
        if temp_wav_file_path is not None:
            try:
                os.remove(temp_wav_file_path)
            except OSError:
                pass
# Maps the encoder result onto the (file, message) output pair
def handle_encode_output(file_path):
    """Turn an ``encode_audio`` result into values for the two encode outputs.

    Args:
        file_path: Path of the encoded file, or ``None`` if encoding failed.

    Returns:
        A ``(file, markdown)`` tuple: the path with a hidden message on
        success, or ``None`` with a visible failure message otherwise.
    """
    encoding_succeeded = file_path is not None
    if encoding_succeeded:
        return file_path, gr.Markdown(visible=False)
    failure_notice = gr.Markdown("Encoding failed. Please ensure you've uploaded an audio file and try again.", visible=True)
    return None, failure_notice
@spaces.GPU(duration=30)
def decode_audio(encoded_file_path):
    """Decode a ``.owie`` file back into a temporary WAV file.

    The file is read as a 4-byte little-endian sample rate followed by an
    LZ4 frame holding int64 tokens, decoded with the module-level
    ``semanticodec`` model, and saved with ``torchaudio``.

    Args:
        encoded_file_path: Path of the ``.owie`` file to decode.

    Returns:
        Path of the temporary WAV file containing the decoded audio.

    Raises:
        gr.Error: If no file was provided or decoding failed. Returning the
            error text instead would hand a non-path string to the audio
            output component.
    """
    global cancel_decode
    if encoded_file_path is None:
        raise gr.Error("Please upload an encoded .owie file before decoding.")
    try:
        # Load encoded data and sample rate
        with open(encoded_file_path, 'rb') as encoded_file:
            sample_rate = int.from_bytes(encoded_file.read(4), byteorder='little')
            compressed_data = encoded_file.read()
        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        # NOTE(review): assumes SemantiCodec tokens are laid out as
        # (batch=1, frames, 2), i.e. the shape encode() returned before the
        # bytes were flattened on disk -- confirm against the library.
        # .copy() because frombuffer() over bytes yields a read-only array.
        tokens_numpy = (
            np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(1, -1, 2).copy()
        )
        # Move the tensor to the same device as the model
        tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)
        print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")
        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)
        # as_tensor() accepts either a tensor or a NumPy array from decode()
        waveform = torch.as_tensor(waveform).detach().squeeze(0).cpu()
        # mkstemp() creates the file atomically, unlike the deprecated mktemp()
        temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_path, waveform, sample_rate)
        return temp_wav_path
    except Exception as e:
        print(f"Decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        raise gr.Error(f"Decoding failed: {e}") from e
    finally:
        cancel_decode = False
@spaces.GPU(duration=30)
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    """Decode a ``.owie`` file chunk by chunk, yielding audio as it is produced.

    Args:
        encoded_file_path: Path of the ``.owie`` file to stream.

    Yields:
        ``(sample_rate, audio)`` tuples where ``audio`` is the transposed
        NumPy form of each decoded chunk. If anything fails, a single block
        of float32 silence is yielded instead of raising.
    """
    global cancel_stream
    # A Cancel click made while nothing was streaming must not abort this run.
    cancel_stream = False
    # Fallbacks so the error path below can still emit silence when the
    # header was never read; 16000 is only a placeholder rate for that case.
    sample_rate = 16000
    chunk_size = sample_rate * 2
    try:
        if encoded_file_path is None:
            raise ValueError("No encoded file provided")
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, 'rb') as encoded_file:
            header_rate = int.from_bytes(encoded_file.read(4), byteorder='little')
            compressed_data = encoded_file.read()
        if header_rate <= 0:
            raise ValueError("Invalid sample rate in .owie header")
        sample_rate = header_rate
        # NOTE(review): this value is used as a number of token frames per
        # decode call (and as a sample count for the silence block); it is
        # kept as in the original -- adjust chunk size as needed.
        chunk_size = sample_rate * 2
        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        # NOTE(review): assumes tokens are (batch=1, frames, 2) as produced by
        # SemantiCodec.encode() -- confirm against the library.
        # .copy() because frombuffer() over bytes yields a read-only array.
        tokens_numpy = (
            np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(1, -1, 2).copy()
        )
        # Move the tensor to the same device as the model
        tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)
        print(f"Streaming tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")
        # Decode the audio in chunks along the frame axis
        with torch.no_grad():
            for start in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation is requested
                tokens_chunk = tokens[:, start:start + chunk_size]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # as_tensor() accepts a tensor or a NumPy array; the result is
                # transposed after dropping the leading dimension.
                audio_data = (
                    torch.as_tensor(audio_chunk).detach().squeeze(0).cpu().numpy().T
                )
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Allow for cancellation check
    except Exception as e:
        print(f"Streaming decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))  # Return silence
    finally:
        cancel_stream = False
# Gradio Interface
# Each tab wires a worker to a button and keeps the returned event so the
# matching Cancel button can pass it to Gradio's `cancels=` in addition to
# setting the module-level flag.
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_error_message = gr.Markdown(visible=False)

        def encode_wrapper(audio):
            """Validate the upload, encode it, and map the result to the outputs."""
            if audio is None:
                return None, gr.Markdown("Please upload an audio file before encoding.", visible=True)
            return handle_encode_output(encode_audio(audio))

        encode_event = encode_button.click(
            encode_wrapper,
            inputs=input_audio,
            outputs=[encoded_output, encode_error_message],
        )
        cancel_encode_button.click(
            lambda: globals().update(cancel_encode=True),
            outputs=None,
            cancels=[encode_event],
        )

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_event = decode_button.click(
            decode_audio, inputs=input_encoded, outputs=decoded_output
        )
        cancel_decode_button.click(
            lambda: globals().update(cancel_decode=True),
            outputs=None,
            cancels=[decode_event],
        )

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_event = stream_button.click(
            stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output
        )
        cancel_stream_button.click(
            lambda: globals().update(cancel_stream=True),
            outputs=None,
            cancels=[stream_event],
        )

demo.queue().launch()