# dac/app.py
import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import AsyncGenerator  # stream_decode_audio below is an async generator
import asyncio  # lets the streaming loop yield control so a cancel can be observed
import traceback  # full stack traces in the error handlers

# Load the SemantiCodec model without specifying a device
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768)

# Global variables for cancellation
cancel_encode = False
cancel_decode = False
cancel_stream = False
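
# A note on the .owie container written by encode_audio and read by both
# decoders below (layout inferred from that code, not a published spec):
#   4 bytes      : sample rate, little-endian unsigned int
#   4 bytes      : number of dimensions of the token array
#   4 bytes/dim  : size of each dimension
#   4 bytes      : byte length of the compressed payload
#   remainder    : LZ4-frame-compressed bytes of the int64 token array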

@spaces.GPU(duration=30)
def encode_audio(audio_file_path):
    global cancel_encode

    if audio_file_path is None:
        print("No audio file provided")
        return None

    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Ensure waveform has a channel dimension
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Save to a temporary WAV file (SemantiCodec encodes from a file path)
        temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)

        # Encode the audio
        tokens = semanticodec.encode(temp_wav_file_path)

        # Convert tokens to NumPy
        tokens_numpy = tokens.detach().cpu().numpy()
        print(f"Tokens shape: {tokens_numpy.shape}")

        # Create a temporary .owie file
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)
        with open(temp_file_path, 'wb') as temp_file:
            # Write sample rate
            temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
            # Write shape information
            temp_file.write(len(tokens_numpy.shape).to_bytes(4, byteorder='little'))
            for dim in tokens_numpy.shape:
                temp_file.write(dim.to_bytes(4, byteorder='little'))
            # Compress and write the tokens data
            compressed_data = lz4.frame.compress(tokens_numpy.tobytes())
            temp_file.write(len(compressed_data).to_bytes(4, byteorder='little'))
            temp_file.write(compressed_data)

        return temp_file_path

    except Exception as e:
        print(f"Encoding error: {e}")
        return None
    finally:
        cancel_encode = False  # reset; the flag is not polled during the blocking encode call
        if 'temp_wav_file_path' in locals():
            os.remove(temp_wav_file_path)
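
# A minimal sketch of reading the .owie container back, shown for reference
# only; decode_audio and stream_decode_audio below inline the same logic and
# do not call this helper:
def read_owie_tokens(path):
    """Return (sample_rate, tokens_numpy) from a .owie file written above."""
    with open(path, 'rb') as f:
        sample_rate = int.from_bytes(f.read(4), byteorder='little')
        ndim = int.from_bytes(f.read(4), byteorder='little')
        shape = tuple(int.from_bytes(f.read(4), byteorder='little') for _ in range(ndim))
        compressed_size = int.from_bytes(f.read(4), byteorder='little')
        tokens_bytes = lz4.frame.decompress(f.read(compressed_size))
    return sample_rate, np.frombuffer(tokens_bytes, dtype=np.int64).reshape(shape)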

# Normalize encode_audio's result into the (file, error message) pair the UI expects
def handle_encode_output(file_path):
    if file_path is None:
        return None, gr.Markdown("Encoding failed. Please ensure you've uploaded an audio file and try again.", visible=True)
    return file_path, gr.Markdown(visible=False)
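
# Both decoders below look up the model's device at call time instead of
# assuming CUDA: under @spaces.GPU a GPU is only attached for the duration of
# the decorated call, so placement should not be cached at import time (an
# assumption about the Spaces ZeroGPU runtime).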

@spaces.GPU(duration=30)
def decode_audio(encoded_file_path):
    global cancel_decode

    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            ndim = int.from_bytes(temp_file.read(4), byteorder='little')
            shape = tuple(int.from_bytes(temp_file.read(4), byteorder='little') for _ in range(ndim))
            compressed_size = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read(compressed_size)
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(shape)

        # Create a tensor from the numpy array and move it to the model's device
        tokens = torch.from_numpy(tokens_numpy)
        model_device = next(semanticodec.parameters()).device
        print(f"Model device: {model_device}")
        tokens = tokens.to(model_device)
        print(f"Tokens device: {tokens.device}")

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)
        print(f"Waveform device: {waveform.device}")

        # Move waveform to CPU and save to a temporary WAV file
        waveform_cpu = waveform.cpu()
        temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")  # mkstemp, not the insecure mktemp
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_path, waveform_cpu.squeeze(0), sample_rate)
        return temp_wav_path

    except Exception as e:
        print(f"Decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        return None  # returning the error string to a filepath Audio output would itself error
    finally:
        cancel_decode = False  # reset; the flag is not polled during the blocking decode call

@spaces.GPU(duration=30)
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    global cancel_stream

    # Fallbacks, used only if an error occurs before the header is parsed
    sample_rate = 16000
    chunk_size = 1024

    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            ndim = int.from_bytes(temp_file.read(4), byteorder='little')
            shape = tuple(int.from_bytes(temp_file.read(4), byteorder='little') for _ in range(ndim))
            compressed_size = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read(compressed_size)
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(shape)

        # Create a tensor from the numpy array and move it to the model's device
        tokens = torch.from_numpy(tokens_numpy)
        model_device = next(semanticodec.parameters()).device
        print(f"Model device: {model_device}")
        tokens = tokens.to(model_device)
        print(f"Streaming tokens device: {tokens.device}")

        # Decode the audio in chunks; note the slice is in token frames along
        # dim 1, not audio samples, so the audible chunk length depends on the
        # model's token rate
        chunk_size = sample_rate * 2  # adjust chunk size as needed
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # exit the loop if cancellation was requested
                tokens_chunk = tokens[:, i:i + chunk_size, :]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Convert to a (samples, channels) numpy array for Gradio
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # yield control so a cancel click can be observed

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))  # emit silence so the stream ends cleanly
    finally:
        cancel_stream = False

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_error_message = gr.Markdown(visible=False)

        def encode_wrapper(audio):
            if audio is None:
                return None, gr.Markdown("Please upload an audio file before encoding.", visible=True)
            return handle_encode_output(encode_audio(audio))

        encode_button.click(
            encode_wrapper,
            inputs=input_audio,
            outputs=[encoded_output, encode_error_message]
        )
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True), outputs=None)

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(lambda: globals().update(cancel_decode=True), outputs=None)

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True), outputs=None)

demo.queue().launch()
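
# Note: queue() matters here. Gradio delivers generator output (the Streaming
# tab) through its queue, and the Cancel buttons depend on the app staying
# responsive between events (assumed from Gradio's documented queueing model).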