dac / app.py
owiedotch's picture
Update app.py
44bab11 verified
raw
history blame
No virus
3.38 kB
import io
import os
import tempfile

import gradio as gr
import spaces
import torch
import dac
from audiotools import AudioSignal
from pydub import AudioSegment
class DACApi:
    """Wrapper around a pretrained Descript Audio Codec (DAC) model for Gradio.

    The model is loaded once at construction. ``encode_audio`` turns an audio
    file into a ``.dac`` file and ``decode_audio`` turns a ``.dac`` file back
    into a WAV file. Both return a filesystem path (what ``gr.File`` /
    ``gr.Audio`` outputs expect). On failure they show a ``gr.Warning`` and
    return ``None`` instead of raising.
    """

    def __init__(self, model_type="44khz", model_bitrate="16kbps"):
        """Fetch (via ``dac.utils.download``) and load the DAC weights.

        Args:
            model_type: Model variant forwarded to ``dac.utils.download``.
            model_bitrate: Bitrate variant forwarded to ``dac.utils.download``.
        """
        self.model_type = model_type
        self.model_bitrate = model_bitrate
        self.model_path = dac.utils.download(model_type, model_bitrate)
        print("Loading DAC model...")
        self.model = dac.DAC.load(self.model_path)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)

    @staticmethod
    def _as_path(file_or_path):
        """Return a filesystem path for a Gradio file value, or ``None``.

        Depending on component settings / Gradio version, a file input arrives
        either as a plain path (``str``/``PathLike``) or as a tempfile-like
        wrapper exposing ``.name``; both are accepted.
        """
        if file_or_path is None:
            return None
        if isinstance(file_or_path, (str, os.PathLike)):
            return os.fspath(file_or_path)
        return file_or_path.name

    @staticmethod
    def _temp_path(suffix):
        """Create an empty, closed temporary file ending in ``suffix``; return its path."""
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        return path

    @spaces.GPU
    def encode_audio(self, audio):
        """Compress an audio file into a DAC ``.dac`` file.

        Args:
            audio: Path to the input audio (as delivered by
                ``gr.Audio(type="filepath")``) or a file-like object with ``.name``.

        Returns:
            Path (``str``) of the written ``.dac`` file, or ``None`` on failure
            (a ``gr.Warning`` is shown to the user).
        """
        audio_path = self._as_path(audio)
        if audio_path is None:
            gr.Warning("Please provide an audio file to encode.")
            return None

        wav_path = None
        try:
            # Normalise any format pydub/ffmpeg can read into a WAV file on disk;
            # AudioSignal is constructed from a path, not from an in-memory buffer.
            wav_path = self._temp_path(".wav")
            AudioSegment.from_file(audio_path).export(wav_path, format="wav")

            # Deliberately left on CPU: compress() moves each chunk to the model's
            # device itself and returns codes on the input signal's device, and
            # DACFile.save() serialises them with .numpy(), which fails on CUDA.
            signal = AudioSignal(wav_path)

            print("Compressing audio...")  # console progress log
            with torch.no_grad():
                compressed = self.model.compress(signal)

            # DACFile.save() needs a real filesystem path (it enforces a ".dac"
            # suffix); use the path it reports back when it returns one.
            dac_path = self._temp_path(".dac")
            saved_path = compressed.save(dac_path)
            return str(saved_path if saved_path is not None else dac_path)
        except Exception as e:
            # Surface the error to the user as a popup instead of crashing the app.
            gr.Warning(f"An error occurred during encoding: {e}")
            return None
        finally:
            # The intermediate WAV is only needed to build the AudioSignal.
            if wav_path is not None:
                try:
                    os.remove(wav_path)
                except OSError:
                    pass

    @spaces.GPU
    def decode_audio(self, compressed_file):
        """Decompress a DAC ``.dac`` file into a WAV file.

        Args:
            compressed_file: Path to the ``.dac`` file or a file-like object
                with ``.name`` (as delivered by ``gr.File``).

        Returns:
            Path (``str``) of the written WAV file, or ``None`` on failure
            (a ``gr.Warning`` is shown to the user).
        """
        compressed_path = self._as_path(compressed_file)
        if compressed_path is None:
            gr.Warning("Please provide a .dac file to decode.")
            return None

        try:
            # NOTE(review): DACFile.load appears to use np.load(..., allow_pickle=True);
            # if so, a crafted upload could execute code. Confirm, and only decode
            # trusted files in a public deployment.
            compressed = dac.DACFile.load(compressed_path)

            print("Decompressing audio...")  # console progress log
            # decompress() moves code chunks to the model's device internally,
            # so the DACFile itself is not moved (it has no .to() method).
            with torch.no_grad():
                decompressed = self.model.decompress(compressed)

            # AudioSignal.write() takes a path and infers the format from its
            # suffix; make sure the samples are on CPU before writing.
            wav_path = self._temp_path(".wav")
            decompressed.to("cpu").write(wav_path)
            return wav_path
        except Exception as e:
            # Surface the error to the user as a popup instead of crashing the app.
            gr.Warning(f"An error occurred during decoding: {e}")
            return None
# Single shared model wrapper used by every Gradio callback below; building it
# fetches and loads the DAC weights once, when this script runs.
dac_api = DACApi(model_type="44khz", model_bitrate="16kbps")
def encode(audio):
    """Callback for the "Encode" button: forward the input to the shared DACApi."""
    return dac_api.encode_audio(audio)
def decode(compressed_file):
    """Callback for the "Decode" button: forward the upload to the shared DACApi."""
    return dac_api.decode_audio(compressed_file)
# Gradio interface: one tab per direction (audio -> compressed file, and back).
with gr.Blocks() as demo:
    gr.Markdown("# Audio Compression with DAC")

    # Encode tab: an uploaded audio file is passed to encode() as a path.
    with gr.Tab("Encode"):
        audio_input = gr.Audio(type="filepath", label="Input Audio")
        encode_button = gr.Button("Encode")
        encoded_output = gr.File(label="Compressed Audio")
        encode_button.click(fn=encode, inputs=audio_input, outputs=encoded_output)

    # Decode tab: an uploaded compressed file is turned back into playable audio.
    with gr.Tab("Decode"):
        compressed_input = gr.File(label="Compressed Audio")
        decode_button = gr.Button("Decode")
        decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(fn=decode, inputs=compressed_input, outputs=decoded_output)

# Turn on request queuing for the app, then start serving it.
demo.queue()
demo.launch()