# Gradio demo: audio compression/decompression with SemantiCodec.
import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import Generator
import asyncio # Import asyncio for cancellation
# Pick the compute device: CUDA when present, CPU otherwise.
try:
    _device_name = "cuda" if torch.cuda.is_available() else "cpu"
    torch_device = torch.device(_device_name)
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

# Instantiate the SemantiCodec model once, at import time, on the chosen device.
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)

# Cooperative-cancellation flags, flipped by the UI "Cancel" buttons.
cancel_encode = False
cancel_decode = False
cancel_stream = False
@spaces.GPU(duration=250)
def encode_audio(audio_file_path):
    """Encode an audio file into a compressed temporary .owie file.

    The .owie layout is: 4 little-endian bytes holding the sample rate,
    followed by the LZ4-compressed raw bytes of the int64 token array
    produced by SemantiCodec.

    Args:
        audio_file_path: Path to the input audio file (any torchaudio-readable format).

    Returns:
        Path to the temporary .owie file containing the encoded audio.
    """
    global cancel_encode  # NOTE(review): flag is declared but never polled here

    # Load the audio and normalize to a (channels, samples) tensor.
    waveform, sample_rate = torchaudio.load(audio_file_path)
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    # SemantiCodec.encode expects a file path, so round-trip through a
    # temporary WAV file; remove it afterwards (the original leaked it).
    temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
    os.close(temp_wav_fd)
    try:
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)
        tokens = semanticodec.encode(temp_wav_file_path)
    finally:
        os.remove(temp_wav_file_path)

    tokens_numpy = tokens.detach().cpu().numpy()

    # Write the .owie container: sample-rate header, then compressed tokens.
    temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
    os.close(temp_fd)
    with open(temp_file_path, 'wb') as temp_file:
        # int(...) guards against a non-int sample-rate type; .to_bytes exists
        # only on Python ints.
        temp_file.write(int(sample_rate).to_bytes(4, byteorder='little'))
        temp_file.write(lz4.frame.compress(tokens_numpy.tobytes()))
    return temp_file_path
@spaces.GPU(duration=250)
def decode_audio(encoded_file_path):
    """Decode a .owie file back into a temporary WAV file.

    Args:
        encoded_file_path: Path to a .owie file written by ``encode_audio``
            (4-byte little-endian sample rate + LZ4-compressed int64 tokens).

    Returns:
        Path to the decoded temporary WAV file.
    """
    global cancel_decode  # NOTE(review): flag is declared but never polled here

    # Parse the .owie container.
    with open(encoded_file_path, 'rb') as temp_file:
        sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
        compressed_data = temp_file.read()
    tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
    # .copy() so the tensor does not share np.frombuffer's read-only memory
    # (torch.from_numpy on a read-only buffer is unsupported/warns).
    tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).copy()
    tokens = torch.from_numpy(tokens_numpy).to(torch_device)

    # SemantiCodec.decode expects a batch dimension.
    if tokens.ndimension() == 2:
        tokens = tokens.unsqueeze(0)
    print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")

    with torch.no_grad():
        waveform = semanticodec.decode(tokens)

    # mkstemp instead of the deprecated, race-prone tempfile.mktemp.
    temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(temp_wav_fd)
    torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
    return temp_wav_path
@spaces.GPU(duration=250)
async def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]:
    """Decode a .owie file in chunks, yielding ``(sample_rate, audio_chunk)``.

    Each yielded chunk is a NumPy array of shape (samples, channels).
    Streaming stops early when the module-level ``cancel_stream`` flag is set.
    On error, prints the exception and yields one chunk of silence.
    """
    global cancel_stream
    # Defaults so the except-branch below cannot hit a NameError when the
    # failure happens before the header is parsed (original bug).
    sample_rate = 16000
    chunk_size = sample_rate
    try:
        # Parse the .owie container (header + LZ4-compressed int64 tokens).
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        # .copy(): np.frombuffer returns a read-only view, which
        # torch.from_numpy cannot safely wrap.
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).copy()
        tokens = torch.from_numpy(tokens_numpy).to(torch_device)
        if tokens.ndimension() == 2:
            tokens = tokens.unsqueeze(0)  # add batch dimension

        # One token chunk per stored sample rate's worth of positions.
        chunk_size = sample_rate
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # cancellation requested by the UI
                tokens_chunk = tokens[:, i:i + chunk_size]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # (samples, channels) layout for Gradio's streaming Audio.
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # yield control so cancellation can land
    except Exception as e:
        print(f"Streaming decoding error: {e}")
        yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))  # silence
    finally:
        cancel_stream = False  # reset for the next streaming session
# ---- Gradio UI: one tab per operation (encode / decode / stream) ----
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        enc_in = gr.Audio(label="Input Audio", type="filepath")
        enc_btn = gr.Button("Encode")
        enc_cancel_btn = gr.Button("Cancel")
        enc_out = gr.File(label="Encoded File (.owie)", type="filepath")
        enc_btn.click(fn=encode_audio, inputs=enc_in, outputs=enc_out)
        # Cancel button flips the module-level flag for the worker to observe.
        enc_cancel_btn.click(fn=lambda: globals().update(cancel_encode=True), outputs=None)

    with gr.Tab("Decode"):
        dec_in = gr.File(label="Encoded File (.owie)", type="filepath")
        dec_btn = gr.Button("Decode")
        dec_cancel_btn = gr.Button("Cancel")
        dec_out = gr.Audio(label="Decoded Audio", type="filepath")
        dec_btn.click(fn=decode_audio, inputs=dec_in, outputs=dec_out)
        dec_cancel_btn.click(fn=lambda: globals().update(cancel_decode=True), outputs=None)

    with gr.Tab("Streaming"):
        stream_in = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_btn = gr.Button("Start Streaming")
        stream_cancel_btn = gr.Button("Cancel")
        stream_out = gr.Audio(label="Streaming Audio Output", streaming=True)
        stream_btn.click(fn=stream_decode_audio, inputs=stream_in, outputs=stream_out)
        stream_cancel_btn.click(fn=lambda: globals().update(cancel_stream=True), outputs=None)

demo.queue().launch()