dac / app.py
owiedotch's picture
Create app.py
f18f98b verified
raw
history blame
No virus
3.57 kB
import gradio as gr
import spaces
import torch
import dac
import numpy as np
from pydub import AudioSegment
from audiotools import AudioSignal
import io
import soundfile as sf
class DACApi:
    """Wrapper around a pretrained Descript Audio Codec (DAC) model.

    The constructor downloads the requested weights, loads them, and moves
    the model to CUDA when available (CPU otherwise). The public methods
    compress an audio file to the DAC format and decompress it again.
    """

    def __init__(self, model_type="44khz", model_bitrate="16kbps"):
        """Download and load the DAC weights for the given type and bitrate."""
        self.model_type = model_type
        self.model_bitrate = model_bitrate
        self.model_path = dac.utils.download(model_type, model_bitrate)
        print("Loading DAC model...")
        self.model = dac.DAC.load(self.model_path)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)

    @staticmethod
    def _as_path(input_file):
        """Return a filesystem path for whatever the UI handed us.

        Accepts a path string, an ``os.PathLike`` object, or an uploaded-file
        wrapper that exposes the path as ``.name`` (the shape the previous
        code assumed).

        Raises:
            TypeError: if no path can be derived from ``input_file``.
        """
        import os

        if isinstance(input_file, (str, os.PathLike)):
            return os.fspath(input_file)
        name = getattr(input_file, "name", None)
        if isinstance(name, str):
            return name
        raise TypeError(
            f"Expected a file path or an object with a .name path, "
            f"got {type(input_file).__name__}"
        )

    def _decompress(self, input_file):
        """Load a .dac file and run it through the decoder (shared helper)."""
        print("Loading compressed audio...")
        compressed = dac.DACFile.load(self._as_path(input_file))
        print("Decompressing audio...")
        return self.model.decompress(compressed)

    @spaces.GPU
    def encode_audio(self, input_file):
        """Compress an audio file and return the .dac bytes.

        Args:
            input_file: path string, path-like, or object with a ``.name``
                path. Files ending in ``.mp3`` are converted to WAV first.

        Returns:
            io.BytesIO positioned at offset 0 holding the serialized DAC file.
        """
        import os
        import tempfile

        source_path = self._as_path(input_file)

        # AudioSignal and DACFile.save are path-based, so intermediates go
        # through a scratch directory that is removed when we are done.
        with tempfile.TemporaryDirectory() as workdir:
            if source_path.lower().endswith('.mp3'):
                print("Converting MP3 to WAV...")
                wav_path = os.path.join(workdir, "input.wav")
                audio = AudioSegment.from_mp3(source_path)
                # export() hands back the open output file; close it so the
                # data is flushed before it is read again.
                audio.export(wav_path, format="wav").close()
            else:
                wav_path = source_path

            signal = AudioSignal(wav_path)

            print("Compressing audio...")
            compressed = self.model.compress(signal)

            # Already carries the .dac suffix, so save() writes exactly here.
            dac_path = os.path.join(workdir, "compressed.dac")
            compressed.save(dac_path)
            with open(dac_path, "rb") as dac_file:
                output = io.BytesIO(dac_file.read())

        output.seek(0)
        return output

    @spaces.GPU
    def decode_audio(self, input_file):
        """Decompress a .dac file and return it as WAV bytes.

        Args:
            input_file: path string, path-like, or object with a ``.name``
                path pointing at a DAC file.

        Returns:
            io.BytesIO positioned at offset 0 holding a WAV file.
        """
        decompressed = self._decompress(input_file)

        # First batch item, transposed so rows are frames and columns are
        # channels, which is the layout soundfile writes.
        # assumes audio_data is (batch, channels, samples) -- TODO confirm
        samples = decompressed.audio_data[0].cpu().numpy().T

        output = io.BytesIO()
        # A file-like target has no extension, so the format is explicit.
        sf.write(output, samples, decompressed.sample_rate, format='WAV')
        output.seek(0)
        return output

    @spaces.GPU
    def stream_audio(self, input_file):
        """Decompress a .dac file and return raw samples for playback.

        Args:
            input_file: path string, path-like, or object with a ``.name``
                path pointing at a DAC file.

        Returns:
            ``(sample_rate, audio_data)`` where ``audio_data`` is a NumPy
            array with singleton dimensions squeezed out and then transposed.
        """
        decompressed = self._decompress(input_file)
        audio_data = decompressed.audio_data.cpu().numpy().squeeze().T
        sample_rate = decompressed.sample_rate
        return (sample_rate, audio_data)
# Single module-level codec instance shared by all handlers below. It is built
# at import time, and DACApi.__init__ downloads and loads the model weights.
dac_api = DACApi()
def encode(audio):
    """Gradio handler: compress ``audio`` and return the path of a .dac file.

    ``dac_api.encode_audio`` produces an in-memory buffer, but the handler's
    output is bound to a ``gr.File`` component, which needs a file on disk.
    The buffer is therefore written to a temporary file and its path returned.

    Args:
        audio: the value supplied by the "Input Audio" component.

    Returns:
        str path of the temporary ``.dac`` file (not deleted automatically).
    """
    import tempfile

    compressed = dac_api.encode_audio(audio)
    # delete=False: the file must outlive this function so the UI can serve it.
    with tempfile.NamedTemporaryFile(suffix=".dac", delete=False) as dac_file:
        dac_file.write(compressed.getvalue())
    return dac_file.name
def decode(audio):
    """Gradio handler: decompress a .dac upload and return a WAV file path.

    ``dac_api.decode_audio`` produces an in-memory WAV buffer, but the
    handler's output is bound to a ``gr.Audio`` component, which does not
    take a file-like object. The buffer is written to a temporary file and
    its path returned instead.

    Args:
        audio: the value supplied by the "Compressed DAC File" component.

    Returns:
        str path of the temporary ``.wav`` file (not deleted automatically).
    """
    import tempfile

    decompressed = dac_api.decode_audio(audio)
    # delete=False: the file must outlive this function so the UI can play it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as wav_file:
        wav_file.write(decompressed.getvalue())
    return wav_file.name
def stream(audio):
    """Gradio handler: decode ``audio`` and hand back ``(rate, samples)``.

    Delegates to ``dac_api.stream_audio`` and forwards its two-element result
    unchanged.
    """
    rate, samples = dac_api.stream_audio(audio)
    return rate, samples
# Gradio interface: three tabs, each with one input, one output, and a button
# that calls the matching module-level handler.
# NOTE(review): nesting was reconstructed from a listing without indentation;
# the buttons are placed below each row, inside the tab -- confirm.
with gr.Blocks() as demo:
    # Encode: audio file in, compressed .dac file out.
    with gr.Tab("Encode"):
        with gr.Row():
            input_audio = gr.Audio(type="filepath", label="Input Audio")
            output_file = gr.File(label="Compressed DAC File")
        encode_button = gr.Button("Encode")
        encode_button.click(encode, inputs=[input_audio], outputs=[output_file])
    # Decode: compressed .dac file in, playable audio out.
    with gr.Tab("Decode"):
        with gr.Row():
            input_file = gr.File(label="Compressed DAC File")
            output_audio = gr.Audio(label="Decompressed Audio")
        decode_button = gr.Button("Decode")
        decode_button.click(decode, inputs=[input_file], outputs=[output_audio])
    # Stream: compressed .dac file in, (sample_rate, samples) out.
    with gr.Tab("Stream"):
        with gr.Row():
            stream_input = gr.File(label="Compressed DAC File")
            stream_output = gr.Audio(label="Streamed Audio")
        stream_button = gr.Button("Stream")
        stream_button.click(stream, inputs=[stream_input], outputs=[stream_output])
# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()