# dac / app.py
# Uploaded by owiedotch -- commit c227215 (verified): "Update app.py"
# File-viewer page metadata: raw | history | blame | no virus | 2.59 kB
import io
import os
import tempfile

# Third-party imports keep their original relative order (import order may
# matter for `spaces` relative to `torch` -- TODO confirm before re-sorting).
import gradio as gr
import spaces
import torch
import dac
from audiotools import AudioSignal
from pydub import AudioSegment
class DACApi:
    """Wrapper around a pretrained DAC model for file-level encode/decode.

    The constructor downloads the weights selected by ``model_type`` and
    ``model_bitrate`` with ``dac.utils.download``, loads them with
    ``dac.DAC.load`` and moves the model to ``'cuda'`` when
    ``torch.cuda.is_available()`` is true, otherwise to ``'cpu'``.

    Attributes:
        model_type: Model family string passed to ``dac.utils.download``.
        model_bitrate: Bitrate string passed to ``dac.utils.download``.
        model_path: Value returned by ``dac.utils.download``.
        model: The loaded DAC model.
        device: ``'cuda'`` or ``'cpu'``, chosen once at construction time.
    """

    def __init__(self, model_type="44khz", model_bitrate="16kbps"):
        self.model_type = model_type
        self.model_bitrate = model_bitrate
        self.model_path = dac.utils.download(model_type, model_bitrate)
        print("Loading DAC model...")
        self.model = dac.DAC.load(self.model_path)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)

    @staticmethod
    def _as_path(file_or_path):
        """Return a filesystem path for a path-like value or a named file object.

        Args:
            file_or_path: A ``str``/``os.PathLike`` path, or any object whose
                ``name`` attribute is a string path (e.g. an open temp file).

        Returns:
            The path as a ``str``.

        Raises:
            ValueError: If ``file_or_path`` is ``None`` (nothing was uploaded).
            TypeError: If no path can be derived from the value.
        """
        if file_or_path is None:
            raise ValueError("No file was provided.")
        if isinstance(file_or_path, (str, os.PathLike)):
            return os.fspath(file_or_path)
        name = getattr(file_or_path, "name", None)
        if isinstance(name, str):
            return name
        raise TypeError(
            f"Expected a file path or an object with a .name path, "
            f"got {type(file_or_path).__name__}"
        )

    @spaces.GPU
    def encode_audio(self, audio):
        """Compress an audio file with the DAC model.

        Args:
            audio: Path to the input audio (what ``gr.Audio(type="filepath")``
                supplies), or an object exposing the path as ``.name``.

        Returns:
            io.BytesIO: The serialized compressed file, positioned at offset 0.

        Raises:
            ValueError: If ``audio`` is ``None``.
            TypeError: If no path can be derived from ``audio``.
        """
        source_path = self._as_path(audio)

        # All intermediates live in a scratch directory that is removed on
        # exit; the libraries involved work on paths, not in-memory streams
        # (AudioSignal and DACFile.save are understood to require paths --
        # see the DAC README usage; TODO confirm against installed versions).
        with tempfile.TemporaryDirectory() as workdir:
            # Convert audio to WAV. pydub's export() hands back the open
            # output handle, so close it before the file is read again.
            wav_path = os.path.join(workdir, "input.wav")
            AudioSegment.from_file(source_path).export(wav_path, format="wav").close()

            # Load audio signal. It is deliberately left on the CPU: the DAC
            # README calls compress() on a CPU signal, and compress() is
            # expected to move each chunk to the model's device itself.
            signal = AudioSignal(wav_path)

            # Compress audio
            print("Compressing audio...")
            compressed = self.model.compress(signal)

            # save() may normalise the suffix, so read back from the path it
            # returns rather than the one requested (falls back to the
            # requested path if save() returns nothing).
            requested = os.path.join(workdir, "compressed.dac")
            saved = compressed.save(requested) or requested
            with open(saved, "rb") as fh:
                output = io.BytesIO(fh.read())

        return output

    @spaces.GPU
    def decode_audio(self, compressed_file):
        """Decompress a DAC file back to WAV audio.

        Args:
            compressed_file: Path to a compressed file, an object exposing
                the path as ``.name``, or a readable binary stream such as
                the ``io.BytesIO`` returned by :meth:`encode_audio`.

        Returns:
            io.BytesIO: WAV-encoded audio, positioned at offset 0.

        Raises:
            ValueError: If ``compressed_file`` is ``None``.
            TypeError: If the value is neither a path, a named file nor a
                readable stream.
        """
        has_name = isinstance(getattr(compressed_file, "name", None), str)
        if hasattr(compressed_file, "read") and not has_name:
            # Anonymous in-memory stream: hand it to the loader as-is
            # (assumes DACFile.load accepts file objects the way numpy.load
            # does -- TODO confirm).
            source = compressed_file
        else:
            source = self._as_path(compressed_file)

        # NOTE(review): DACFile.load is believed to unpickle its input
        # (numpy.load with allow_pickle=True). Files uploaded by users are
        # untrusted -- confirm and consider validating before loading.
        compressed = dac.DACFile.load(source)

        # No explicit device transfer here: DACFile is a plain container and
        # decompress() is expected to move the codes to the model's device.
        print("Decompressing audio...")
        decompressed = self.model.decompress(compressed)

        with tempfile.TemporaryDirectory() as workdir:
            wav_path = os.path.join(workdir, "decoded.wav")
            # write() takes a destination path; make sure the samples are on
            # the CPU first so they can be converted for writing.
            decompressed.to("cpu").write(wav_path)
            with open(wav_path, "rb") as fh:
                output = io.BytesIO(fh.read())

        return output
# Module-level DACApi instance used by the encode()/decode() callbacks.
# Constructing it runs DACApi.__init__ with its default arguments, which
# downloads and loads the model when this module is executed.
dac_api = DACApi()
def encode(audio):
    """Gradio callback: compress the uploaded audio and return a file path.

    Args:
        audio: Value supplied by the "Input Audio" component (a file path,
            or ``None`` when nothing has been uploaded).

    Returns:
        str: Path of a ``.dac`` file on disk for the ``gr.File`` output.

    Raises:
        gr.Error: If no audio was provided.
    """
    if audio is None:
        raise gr.Error("Please provide an audio file to encode.")

    compressed = dac_api.encode_audio(audio)

    # Already a path on disk: nothing to materialise.
    if isinstance(compressed, (str, os.PathLike)):
        return os.fspath(compressed)

    # A file output component serves files from disk, so an in-memory stream
    # has to be written out first. delete=False keeps the file alive after
    # this function returns; cleanup is left to the OS temp directory.
    with tempfile.NamedTemporaryFile(suffix=".dac", delete=False) as tmp:
        tmp.write(compressed.read())
    return tmp.name
def decode(compressed_file):
    """Gradio callback: decompress an uploaded file and return a WAV path.

    Args:
        compressed_file: Value supplied by the "Compressed Audio" upload
            component (``None`` when nothing has been uploaded).

    Returns:
        str: Path of a ``.wav`` file on disk for the ``gr.Audio`` output.

    Raises:
        gr.Error: If no compressed file was provided.
    """
    if compressed_file is None:
        raise gr.Error("Please provide a compressed file to decode.")

    decompressed = dac_api.decode_audio(compressed_file)

    # Already a path on disk: nothing to materialise.
    if isinstance(decompressed, (str, os.PathLike)):
        return os.fspath(decompressed)

    # The audio output component needs a file path (or a sample-rate/array
    # pair), not a byte stream, so write the WAV bytes out first.
    # delete=False keeps the file alive after this function returns.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(decompressed.read())
    return tmp.name
# Gradio interface: one tab per direction, each with an input, a trigger
# button and an output. Components are laid out first; the click handlers
# are attached afterwards, still inside the Blocks context and in the same
# order (encode, then decode).
with gr.Blocks() as demo:
    gr.Markdown("# Audio Compression with DAC")

    with gr.Tab("Encode"):
        audio_input = gr.Audio(type="filepath", label="Input Audio")
        encode_button = gr.Button("Encode")
        encoded_output = gr.File(label="Compressed Audio")

    with gr.Tab("Decode"):
        compressed_input = gr.File(label="Compressed Audio")
        decode_button = gr.Button("Decode")
        decoded_output = gr.Audio(label="Decompressed Audio")

    # Encode tab: uploaded audio -> compressed file.
    encode_button.click(
        fn=encode,
        inputs=[audio_input],
        outputs=[encoded_output],
    )
    # Decode tab: compressed file -> playable audio.
    decode_button.click(
        fn=decode,
        inputs=[compressed_input],
        outputs=[decoded_output],
    )

# Start the app as soon as the module is executed.
demo.launch()