dac / app.py
owiedotch's picture
Update app.py
dd53483 verified
raw
history blame
No virus
2.62 kB
import gradio as gr
import spaces
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio
from encodec.compress import compress_to_file, decompress_from_file
import io
# Pick the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Build the 48 kHz Encodec model, place it on the chosen device and fix the
# target bandwidth used for compression.
model = EncodecModel.encodec_model_48khz()
model = model.to(device)
model.set_target_bandwidth(6.0)
@spaces.GPU  # Indicate GPU usage for Spaces environment (if applicable)
def encode(audio_file_path):
    """Compress an audio file with Encodec and return the path of the result.

    Args:
        audio_file_path: Path of the audio file to compress, as delivered by
            the ``gr.Audio(type="filepath")`` input. ``None`` when the user
            has not supplied any audio.

    Returns:
        The path (``str``) of a temporary ``.ecdc`` file holding the
        compressed stream, or ``None`` if the input is missing or encoding
        fails (a Gradio warning is shown in both cases).
    """
    # Local import so this function is self-contained.
    import tempfile

    if audio_file_path is None:
        gr.Warning("An error occurred during encoding: no input audio provided")
        return None

    try:
        # Load and pre-process the audio waveform.
        wav, sr = torchaudio.load(audio_file_path)

        # Down-mix multi-channel input to a single channel before handing it
        # to convert_audio.
        # NOTE(review): this discards stereo separation even when
        # model.channels > 1 -- confirm that is intended.
        if wav.shape[0] > 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        wav = convert_audio(wav, sr, model.sample_rate, model.channels)
        wav = wav.to(device)  # The input must live on the model's device.

        # Compress in memory first so that a failure part-way through does
        # not leave a truncated file on disk.
        buffer = io.BytesIO()
        compress_to_file(model, wav, buffer)

        # gr.File outputs expect a file path, not a file-like object, so
        # persist the compressed bytes to a temporary .ecdc file.
        with tempfile.NamedTemporaryFile(suffix=".ecdc", delete=False) as out_file:
            out_file.write(buffer.getvalue())
        return out_file.name
    except Exception as e:
        # Top-level UI boundary: report the problem to the user instead of
        # crashing the request.
        gr.Warning(f"An error occurred during encoding: {e}")
        return None
@spaces.GPU
def decode(compressed_audio_file):
    """Decompress an Encodec ``.ecdc`` file into audio for ``gr.Audio``.

    Args:
        compressed_audio_file: Value delivered by the ``gr.File`` input:
            either a file path or an object exposing the path as ``.name``
            (which of the two depends on the Gradio version). ``None`` when
            no file was uploaded.

    Returns:
        A ``(sample_rate, samples)`` tuple, where ``samples`` is a numpy
        array with the sample axis first, or ``None`` if the input is
        missing or decoding fails (a Gradio warning is shown in both cases).
    """
    if compressed_audio_file is None:
        gr.Warning("An error occurred during decoding: no compressed file provided")
        return None

    try:
        # Accept both a plain path and an upload wrapper with a .name path.
        path = getattr(compressed_audio_file, "name", compressed_audio_file)

        # decompress_from_file reads from a binary file object, so open the
        # uploaded file explicitly instead of passing the path through.
        with open(path, "rb") as fo:
            wav, sr = decompress_from_file(fo, device=device)

        # gr.Audio needs the sample rate alongside the data, with samples on
        # the first axis; assumes decompress_from_file returns a
        # channels-first tensor, hence the transpose.
        decoded_audio = wav.detach().cpu().numpy().T
        return sr, decoded_audio
    except Exception as e:
        # Top-level UI boundary: report the problem to the user instead of
        # crashing the request.
        gr.Warning(f"An error occurred during decoding: {e}")
        return None
# Gradio interface: one tab to compress audio, one tab to decompress it.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>Audio Compression with Encodec</h1>")

    # --- Encode tab: audio in, .ecdc file out ---
    with gr.Tab("Encode"):
        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")
            encode_button = gr.Button("Encode", variant="primary")
        with gr.Row():
            encoded_output = gr.File(label="Compressed Audio (.ecdc)")
        encode_button.click(
            fn=encode,
            inputs=audio_input,
            outputs=encoded_output,
        )

    # --- Decode tab: .ecdc file in, audio out ---
    with gr.Tab("Decode"):
        with gr.Row():
            compressed_input = gr.File(label="Compressed Audio (.ecdc)")
            decode_button = gr.Button("Decode", variant="primary")
        with gr.Row():
            decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(
            fn=decode,
            inputs=compressed_input,
            outputs=decoded_output,
        )

# Enable request queuing, then start the app.
queued_demo = demo.queue()
queued_demo.launch()