dac / app.py
owiedotch's picture
Update app.py
cefd33c verified
raw
history blame
No virus
7.64 kB
import asyncio  # Import asyncio for cancellation
import os
import tempfile
import traceback  # Import traceback for error handling
from typing import AsyncGenerator, Generator

import gradio as gr
import lz4.frame
import numpy as np
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
# Choose the compute device once, at import time: CUDA when it is available,
# otherwise CPU. Any error while probing CUDA also falls back to CPU instead of
# aborting start-up.
try:
    torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

# Build the SemantiCodec model and move it onto the selected device.
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)

# Cancellation flags, set to True by the UI "Cancel" buttons; each worker
# function resets its own flag to False in its `finally` clause.
cancel_encode = cancel_decode = cancel_stream = False
def _remove_temp_file(path):
    """Best-effort delete of a temporary file; failures are printed, not raised."""
    if path is None:
        return
    try:
        os.remove(path)
    except OSError as cleanup_error:
        print(f"Could not remove temporary file {path}: {cleanup_error}")


@spaces.GPU(duration=30)
def encode_audio(audio_file_path):
    """Encode an audio file into a compressed ``.owie`` token file.

    File layout written here:
      * 4 bytes: the input's sample rate as a little-endian unsigned integer;
      * remaining bytes: one LZ4 frame holding the token array as raw int64
        bytes in C order (the array shape is NOT stored).

    Args:
        audio_file_path: Path to the uploaded audio file, or ``None``.

    Returns:
        Path to the newly created ``.owie`` file, or ``None`` if no file was
        given or encoding failed (the error and traceback are printed).
    """
    global cancel_encode
    if audio_file_path is None:
        print("No audio file provided")
        return None

    # Sentinels instead of probing locals() in the cleanup paths.
    temp_wav_file_path = None
    temp_file_path = None
    try:
        waveform, sample_rate = torchaudio.load(audio_file_path)
        # torchaudio.save expects a (channels, frames) tensor.
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Re-save as a plain WAV; semanticodec.encode is given a file path, so
        # this presumably normalises whatever format was uploaded.
        temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)

        tokens = semanticodec.encode(temp_wav_file_path)
        # The decoders read the payload back with dtype=np.int64, so write that
        # dtype explicitly; any other token dtype would otherwise be misread.
        tokens_numpy = tokens.detach().cpu().numpy().astype(np.int64, copy=False)
        print(f"Tokens shape: {tokens_numpy.shape}")
        # NOTE(review): only the flat token bytes are stored; the shape printed
        # above is lost and the decoders reshape to 1-D. Confirm that
        # semanticodec.decode accepts that layout.

        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)
        with open(temp_file_path, 'wb') as temp_file:
            temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
            temp_file.write(lz4.frame.compress(tokens_numpy.tobytes()))
        return temp_file_path
    except Exception as e:
        print(f"Encoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        # Do not leave a partially written .owie file behind.
        _remove_temp_file(temp_file_path)
        return None
    finally:
        cancel_encode = False
        # Always drop the intermediate WAV; a cleanup failure must not mask the
        # function's result.
        _remove_temp_file(temp_wav_file_path)
def handle_encode_output(file_path):
    """Map encode_audio's result onto the (encoded file, error message) outputs.

    On success returns the file path plus a hidden Markdown component. When
    encoding produced no file (``None``), returns ``None`` plus a visible
    Markdown explaining the failure.
    """
    if file_path is not None:
        return file_path, gr.Markdown(visible=False)
    failure_text = "Encoding failed. Please ensure you've uploaded an audio file and try again."
    return None, gr.Markdown(failure_text, visible=True)
@spaces.GPU(duration=30)
def decode_audio(encoded_file_path):
    """Decode a ``.owie`` token file back into a WAV file.

    Reads the 4-byte little-endian sample-rate header, LZ4-decompresses the
    remaining bytes as a flat int64 token array, runs ``semanticodec.decode``
    and writes the resulting waveform to a new temporary ``.wav`` file.

    Args:
        encoded_file_path: Path to a ``.owie`` file produced by
            ``encode_audio``, or ``None`` when nothing was uploaded.

    Returns:
        Path to the decoded ``.wav`` file, or ``None`` if no file was given.

    Raises:
        gr.Error: If reading or decoding fails. Gradio then shows the message
            to the user instead of receiving an error string as a file path.
    """
    global cancel_decode
    if encoded_file_path is None:
        print("No encoded file provided")
        return None
    try:
        with open(encoded_file_path, 'rb') as encoded_file:
            sample_rate = int.from_bytes(encoded_file.read(4), byteorder='little')
            compressed_data = encoded_file.read()
        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        # np.frombuffer returns a read-only view over the bytes; copy it so
        # torch.from_numpy receives a writable array.
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(-1).copy()
        # Move the tensor to the same device as the model.
        tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)
        print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")

        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        # mkstemp creates the file atomically; tempfile.mktemp is deprecated
        # because another process can grab the name before it is used.
        temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        # NOTE(review): this saves at the sample rate of the *original* input
        # (from the header). If semanticodec.decode emits audio at a fixed model
        # rate, playback speed will be wrong; confirm the model's output rate.
        torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
        return temp_wav_path
    except Exception as e:
        print(f"Decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        raise gr.Error(f"Decoding failed: {e}") from e
    finally:
        cancel_decode = False
@spaces.GPU(duration=30)
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    """Decode a ``.owie`` file chunk by chunk, yielding audio as it is produced.

    ``cancel_stream`` is checked before each chunk so the Cancel button can
    stop the loop early. The flag is reset when the generator finishes.

    Args:
        encoded_file_path: Path to a ``.owie`` file, or ``None``.

    Yields:
        ``(sample_rate, audio_data)`` tuples for a streaming ``gr.Audio``. If
        decoding fails after a valid header was read, a single chunk of
        2 seconds of silence (``sample_rate * 2`` zero samples) is yielded
        instead.

    Raises:
        gr.Error: If the failure happens before a usable (non-zero) sample rate
            was read, since no silent chunk can be built without it.
    """
    global cancel_stream
    if encoded_file_path is None:
        print("No encoded file provided for streaming")
        return

    # Set before the try so the error path can tell whether the header was read.
    sample_rate = None
    try:
        with open(encoded_file_path, 'rb') as encoded_file:
            sample_rate = int.from_bytes(encoded_file.read(4), byteorder='little')
            compressed_data = encoded_file.read()
        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        # Copy the read-only np.frombuffer view so torch.from_numpy gets a
        # writable array.
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(-1).copy()
        # Move the tensor to the same device as the model.
        tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)
        print(f"Streaming tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")

        # NOTE(review): chunk_size counts *tokens* but is derived from the audio
        # sample rate (e.g. 88200 tokens at 44.1 kHz). If token_rate=100 means
        # 100 tokens per second, each chunk covers far more than 2 s of audio.
        # Confirm the intended chunk length.
        chunk_size = sample_rate * 2
        with torch.no_grad():
            for start in range(0, tokens.shape[0], chunk_size):
                if cancel_stream:
                    break
                # semanticodec.decode is synchronous and blocks the event loop
                # while it runs.
                audio_chunk = semanticodec.decode(tokens[start:start + chunk_size])
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                # Hand control back to the event loop so the cancel click can run.
                await asyncio.sleep(0)
    except Exception as e:
        print(f"Streaming decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        if not sample_rate:
            raise gr.Error(f"Streaming decode failed: {e}") from e
        # Best effort: 2 seconds of silence at the header's sample rate.
        yield (sample_rate, np.zeros((sample_rate * 2, 1), dtype=np.float32))
    finally:
        cancel_stream = False
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_error_message = gr.Markdown(visible=False)

        def encode_wrapper(audio):
            """Validate the upload, run encode_audio and format its UI outputs."""
            if audio is None:
                return None, gr.Markdown("Please upload an audio file before encoding.", visible=True)
            return handle_encode_output(encode_audio(audio))

        encode_event = encode_button.click(
            encode_wrapper,
            inputs=input_audio,
            outputs=[encoded_output, encode_error_message],
        )
        # encode_audio never reads cancel_encode, so setting the flag alone
        # cannot stop it. `cancels=` lets Gradio drop a queued encode request.
        # Per the Gradio docs, a function already running is allowed to finish.
        cancel_encode_button.click(
            lambda: globals().update(cancel_encode=True),
            outputs=None,
            cancels=[encode_event],
        )

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")
        decode_event = decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        # Same situation as encoding: decode_audio never reads cancel_decode.
        cancel_decode_button.click(
            lambda: globals().update(cancel_decode=True),
            outputs=None,
            cancels=[decode_event],
        )

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)
        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        # stream_decode_audio polls cancel_stream between chunks, so the flag
        # stops it. NOTE(review): if @spaces.GPU runs the worker in a separate
        # process, a global set here may not be visible to it; confirm on ZeroGPU.
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True), outputs=None)

demo.queue().launch()