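"""Gradio app for audio compression with SemantiCodec.

Encodes audio into LZ4-compressed token files (.owie), decodes them back to
WAV, and streams decoded audio in chunks, using the GPU when available.
"""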
import gradio as gr
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import AsyncGenerator  # stream_decode_audio is an async generator
import spaces
import asyncio # Import asyncio for cancellation
# Attempt to use the GPU, fall back to CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")
# Load the SemantiCodec model
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)
# Global flags for cooperative cancellation
cancel_encode = False
cancel_decode = False
cancel_stream = False
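# Cancellation works by polling: the Cancel buttons flip these flags via
# globals().update(...), and the streaming loop checks cancel_stream between
# chunks. Encode and decode run as single blocking calls, so their flags are
# only reset when the call finishes.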
@spaces.GPU(duration=500)  # Increased GPU duration to 500 seconds
def encode_audio(audio_file_path):
    global cancel_encode

    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Encode the audio (add a batch dimension and move to the model's device)
        audio = waveform.unsqueeze(0).to(torch_device)
        with torch.no_grad():
            tokens = semanticodec.encode(audio)

        # Convert to NumPy and save to a temporary .owie file
        tokens_numpy = tokens.detach().cpu().numpy()

        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)
        with open(temp_file_path, 'wb') as temp_file:
            # .owie header: the sample rate, then the rank and dimensions of
            # the token array, all as little-endian 4-byte integers, so that
            # decoding can restore the exact tensor shape the encoder produced
            temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
            temp_file.write(tokens_numpy.ndim.to_bytes(4, byteorder='little'))
            for dim in tokens_numpy.shape:
                temp_file.write(dim.to_bytes(4, byteorder='little'))
            # Compress and write the encoded data
            compressed_data = lz4.frame.compress(tokens_numpy.tobytes())
            temp_file.write(compressed_data)

        return temp_file_path

    except Exception as e:
        # Surface the failure in the UI; returning an error string would be
        # misread as a file path by the gr.File output
        raise gr.Error(f"Encoding error: {e}")

    finally:
        cancel_encode = False  # Reset cancel flag after encoding
@spaces.GPU(duration=500)  # Increased GPU duration to 500 seconds
def decode_audio(encoded_file_path):
    global cancel_decode

    try:
        # Load the header (sample rate and token shape) and the compressed
        # token data from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            ndim = int.from_bytes(temp_file.read(4), byteorder='little')
            shape = tuple(int.from_bytes(temp_file.read(4), byteorder='little')
                          for _ in range(ndim))
            compressed_data = temp_file.read()

        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        # Tokens are assumed to be int64 (the dtype of a torch LongTensor);
        # copy() because np.frombuffer returns a read-only view
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(shape)
        tokens = torch.from_numpy(tokens_numpy.copy()).to(torch_device)

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        # Save to a temporary WAV file (mkstemp avoids the race condition of
        # the deprecated tempfile.mktemp)
        temp_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_fd)
        torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
        return temp_wav_path

    except Exception as e:
        raise gr.Error(f"Decoding error: {e}")

    finally:
        cancel_decode = False  # Reset cancel flag after decoding
@spaces.GPU(duration=500)  # Increased GPU duration to 500 seconds
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    global cancel_stream

    # Fallbacks so the error path below can still yield silence even if the
    # failure happens before the header has been read
    sample_rate = 16000
    chunk_size = sample_rate

    try:
        # Load the header (sample rate and token shape) and the compressed
        # token data from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            ndim = int.from_bytes(temp_file.read(4), byteorder='little')
            shape = tuple(int.from_bytes(temp_file.read(4), byteorder='little')
                          for _ in range(ndim))
            compressed_data = temp_file.read()

        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(shape)
        tokens = torch.from_numpy(tokens_numpy.copy()).to(torch_device)

        # Decode the audio in chunks along the token sequence dimension
        chunk_size = sample_rate  # Use the stored sample rate as chunk size
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation is requested

                tokens_chunk = tokens[:, i:i + chunk_size]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Convert to a (samples, channels) numpy array for Gradio streaming
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Yield control so a cancellation request can be seen

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))  # Return silence

    finally:
        cancel_stream = False  # Reset cancel flag after streaming
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_button.click(encode_audio, inputs=input_audio, outputs=encoded_output)
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True),
                                   outputs=None)  # Set cancel_encode flag

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")
        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(lambda: globals().update(cancel_decode=True),
                                   outputs=None)  # Set cancel_decode flag

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)
        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True),
                                   outputs=None)  # Set cancel_stream flag

demo.queue().launch()
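# Hypothetical local smoke test, bypassing the UI ("input.wav" is a placeholder):
#   owie_path = encode_audio("input.wav")   # -> temporary compressed .owie file
#   wav_path = decode_audio(owie_path)      # -> temporary reconstructed .wav file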