# dac/app.py: SemantiCodec audio compression demo
import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import AsyncGenerator
import asyncio  # cooperative cancellation in the streaming decoder
import traceback  # full tracebacks in error handlers
# Attempt to use GPU, fallback to CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")
# Load the SemantiCodec model
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)
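# Minimal usage sketch (hypothetical; not executed by the app). SemantiCodec's
# encode takes a WAV path and returns a token tensor; decode maps tokens back
# to a waveform. These are the same two calls the handlers below rely on:
#
#   tokens = semanticodec.encode("example.wav")
#   waveform = semanticodec.decode(tokens)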
# Cooperative-cancellation flags: the Cancel buttons set these to True, and the
# long-running handlers reset them (the streaming loop polls cancel_stream
# between chunks)
cancel_encode = False
cancel_decode = False
cancel_stream = False
@spaces.GPU(duration=30) # Changed from 250 to 30
def encode_audio(audio_file_path):
    global cancel_encode

    if audio_file_path is None:
        print("No audio file provided")
        return None

    temp_wav_file_path = None
    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Ensure the waveform has a channel dimension
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Save to a temporary WAV file (SemantiCodec encodes from a file path)
        temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)

        # Encode the audio
        tokens = semanticodec.encode(temp_wav_file_path)

        # Convert tokens to NumPy
        tokens_numpy = tokens.detach().cpu().numpy()
        print(f"Original tokens shape: {tokens_numpy.shape}")

        # Normalize the token array to 2D
        if tokens_numpy.ndim == 1:
            tokens_numpy = tokens_numpy.reshape(1, -1)
        elif tokens_numpy.ndim == 2:
            pass  # Already 2D
        elif tokens_numpy.ndim == 3 and tokens_numpy.shape[0] == 1:
            tokens_numpy = tokens_numpy.squeeze(0)
        else:
            raise ValueError(f"Unexpected tokens array shape: {tokens_numpy.shape}")
        print(f"Reshaped tokens shape: {tokens_numpy.shape}")

        # Create the temporary .owie file
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)
        with open(temp_file_path, 'wb') as temp_file:
            # Write the sample rate as a 4-byte little-endian header
            temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
            # Compress and write the token data
            compressed_data = lz4.frame.compress(tokens_numpy.tobytes())
            temp_file.write(compressed_data)

        return temp_file_path
    except Exception as e:
        print(f"Encoding error: {e}")
        return None  # Signal failure to the UI rather than returning the error text
    finally:
        cancel_encode = False  # Reset the cancel flag after encoding
        if temp_wav_file_path is not None and os.path.exists(temp_wav_file_path):
            os.remove(temp_wav_file_path)  # Clean up the temporary WAV file
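# The .owie container written above is just a 4-byte little-endian sample rate
# followed by a single LZ4 frame holding the int64 token array. A minimal
# standalone reader sketch (read_owie is a hypothetical helper, not used by
# this app):
def read_owie(path):
    with open(path, 'rb') as f:
        sample_rate = int.from_bytes(f.read(4), byteorder='little')
        tokens = np.frombuffer(lz4.frame.decompress(f.read()), dtype=np.int64)
    return sample_rate, tokens.reshape(1, -1, 2)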
# Route the encoder result to the UI, showing an error message on failure
def handle_encode_output(file_path):
    if file_path is None:
        return None, gr.Markdown(
            "Encoding failed. Please ensure you've uploaded an audio file and try again.",
            visible=True,
        )
    return file_path, gr.Markdown(visible=False)
@spaces.GPU(duration=30) # Changed from 250 to 30
def decode_audio(encoded_file_path):
    global cancel_decode

    if encoded_file_path is None:
        return None

    try:
        # Load the encoded data and sample rate
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)

        # Restore the original (1, N, 2) token layout; np.frombuffer returns a
        # read-only view, so make a writable copy for torch.from_numpy
        tokens_numpy = np.array(tokens_numpy.reshape(1, -1, 2), copy=True)

        # Move the tensor to the same device as the model
        tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)
        print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        # Move the waveform to the CPU and save to a temporary WAV file
        # (mkstemp avoids the race condition in the deprecated mktemp)
        waveform_cpu = waveform.cpu()
        temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_path, waveform_cpu.squeeze(0), sample_rate)
        return temp_wav_path
    except Exception as e:
        print(f"Decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        return None  # A gr.Audio filepath output cannot render an error string
    finally:
        cancel_decode = False  # Reset the cancel flag after decoding
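# Round-trip sketch (hypothetical local test, not wired into the UI):
#
#   owie_path = encode_audio("input.wav")  # writes the .owie container
#   wav_path = decode_audio(owie_path)     # reconstructs a playable WAV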
@spaces.GPU(duration=30) # Changed from 250 to 30
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    global cancel_stream

    sample_rate = 16000  # Fallback used only if reading the header fails (assumed default)
    try:
        # Load the encoded data and sample rate from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)

        # Restore the (1, N, 2) token layout; np.frombuffer is read-only, so copy
        tokens_numpy = np.array(tokens_numpy.reshape(1, -1, 2), copy=True)

        # Move the tensor to the same device as the model
        tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)
        print(f"Streaming tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")

        # Decode the audio in chunks along the token time axis
        chunk_size = sample_rate // 2
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation is requested
                tokens_chunk = tokens[:, i:i + chunk_size, :]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Convert to a (samples, channels) NumPy array for Gradio
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Yield control so the cancel flag can be observed
    except Exception as e:
        print(f"Streaming decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        # chunk_size may be unbound if the failure happened before decoding,
        # so emit a fixed-length stretch of silence instead
        yield (sample_rate, np.zeros((1024, 1), dtype=np.float32))
    finally:
        cancel_stream = False  # Reset the cancel flag after streaming
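# Back-of-envelope for the chunk size above (hedged: assumes token_rate=100 in
# the constructor means 100 token frames per second of audio):
def approx_chunk_seconds(sample_rate: int, token_rate: int = 100) -> float:
    """Hypothetical helper: rough seconds of audio covered by one streaming chunk."""
    return (sample_rate // 2) / token_rate  # e.g. 80.0 at a 16 kHz sample rate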
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_error_message = gr.Markdown(visible=False)

        def encode_wrapper(audio):
            if audio is None:
                return None, gr.Markdown("Please upload an audio file before encoding.", visible=True)
            return handle_encode_output(encode_audio(audio))

        encode_button.click(
            encode_wrapper,
            inputs=input_audio,
            outputs=[encoded_output, encode_error_message]
        )
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True), outputs=None)  # Set cancel_encode flag

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(lambda: globals().update(cancel_decode=True), outputs=None)  # Set cancel_decode flag

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True), outputs=None)  # Set cancel_stream flag

demo.queue().launch()