import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import AsyncGenerator
import asyncio


try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)

# Cancellation flags toggled by the "Cancel" buttons in the UI.
cancel_encode = False
cancel_decode = False
cancel_stream = False
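

# On-disk ".owie" format written by encode_audio and read back by the decoders:
#   bytes 0-3: original sample rate as a 4-byte little-endian integer
#   bytes 4+ : LZ4-compressed raw bytes of the int64 token array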
@spaces.GPU(duration=250)
def encode_audio(audio_file_path):
    global cancel_encode

    waveform, sample_rate = torchaudio.load(audio_file_path)

    # Ensure a channel dimension is present.
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    # Write the audio to a temporary WAV file and hand its path to the encoder.
    temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
    os.close(temp_wav_fd)
    torchaudio.save(temp_wav_file_path, waveform, sample_rate)

    tokens = semanticodec.encode(temp_wav_file_path)
    # Clean up the intermediate WAV now that encoding is done.
    os.remove(temp_wav_file_path)

    tokens_numpy = tokens.detach().cpu().numpy()
    temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
    os.close(temp_fd)
    with open(temp_file_path, 'wb') as temp_file:
        # Sample-rate header, then the LZ4-compressed token payload.
        temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
        compressed_data = lz4.frame.compress(tokens_numpy.tobytes())
        temp_file.write(compressed_data)

    return temp_file_path


@spaces.GPU(duration=250)
def decode_audio(encoded_file_path):
    global cancel_decode

    # Read the sample-rate header, then decompress the token payload.
    with open(encoded_file_path, 'rb') as temp_file:
        sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
        compressed_data = temp_file.read()
        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)

    # Reshape the flat token array back to a 3-D (1, length, 1) tensor before decoding.
    if tokens_numpy.ndim == 1:
        tokens_numpy = tokens_numpy.reshape(1, -1, 1)

    tokens = torch.from_numpy(tokens_numpy).to(torch_device)

    print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")

    with torch.no_grad():
        waveform = semanticodec.decode(tokens)

    # Write the reconstructed audio to a temporary WAV file for Gradio to serve.
    temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(temp_wav_fd)
    torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
    return temp_wav_path
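

# Streaming decoder: tokens are decoded in chunks and each chunk is yielded as a
# (sample_rate, numpy_array) tuple, the streaming format accepted by a gr.Audio output.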
@spaces.GPU(duration=250)
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    global cancel_stream
    # Placeholder values, only used by the error path below if reading the file fails early.
    sample_rate, chunk_size = 16000, 16000
    try:
        # Read the sample-rate header, then decompress the token payload.
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)

        # Reshape the flat token array back to a 3-D (1, length, 1) tensor before decoding.
        if tokens_numpy.ndim == 1:
            tokens_numpy = tokens_numpy.reshape(1, -1, 1)

        tokens = torch.from_numpy(tokens_numpy).to(torch_device)

        if tokens.ndimension() == 2:
            tokens = tokens.unsqueeze(0)

        # Decode the token sequence chunk by chunk, yielding each piece as it is ready.
        chunk_size = sample_rate
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break

                tokens_chunk = tokens[:, i:i + chunk_size]
                audio_chunk = semanticodec.decode(tokens_chunk)

                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                # Give the event loop a chance to process a cancel request.
                await asyncio.sleep(0)

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        # Yield a short silent chunk so the Audio component still receives valid output.
        yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))

    finally:
        cancel_stream = False
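

# Gradio UI: three tabs for encoding to .owie, decoding a whole file, and streaming playback.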
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")

        encode_button.click(encode_audio, inputs=input_audio, outputs=encoded_output)
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True),
                                   outputs=None)

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(lambda: globals().update(cancel_decode=True),
                                   outputs=None)

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True),
                                   outputs=None)

demo.queue().launch()