|
import io
import os
import tempfile

# Third-party imports keep their original order: `spaces` stays ahead of
# torch (Hugging Face ZeroGPU expects to be imported before CUDA is set up).
import gradio as gr
import spaces
import torch
import dac
import numpy as np
from pydub import AudioSegment
from audiotools import AudioSignal
import soundfile as sf
|
|
|
class DACApi:
    """Wrapper around a pretrained Descript Audio Codec (DAC) model.

    The model is downloaded and loaded once in ``__init__``. The public methods
    accept the values produced by this app's Gradio inputs. These are either a
    path string or a file object exposing ``.name``, depending on the component
    and Gradio version. The methods return values the Gradio outputs can render:
    file paths, or a ``(sample_rate, ndarray)`` tuple.
    """

    def __init__(self, model_type="44khz", model_bitrate="16kbps"):
        """Download (if needed) and load the DAC model onto the best device.

        Args:
            model_type: Model variant name passed to ``dac.utils.download``.
            model_bitrate: Bitrate variant name passed to ``dac.utils.download``.
        """
        self.model_type = model_type
        self.model_bitrate = model_bitrate
        self.model_path = dac.utils.download(model_type, model_bitrate)
        print("Loading DAC model...")
        self.model = dac.DAC.load(self.model_path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    @staticmethod
    def _as_path(input_file):
        """Return the filesystem path for a Gradio file value.

        Accepts a path string / ``os.PathLike`` or any object with ``.name``.

        Raises:
            ValueError: If ``input_file`` is ``None`` (nothing was uploaded).
        """
        if input_file is None:
            raise ValueError("No input file was provided.")
        if isinstance(input_file, (str, os.PathLike)):
            return os.fspath(input_file)
        return input_file.name

    @staticmethod
    def _temp_path(source_path, suffix):
        """Return a path in a fresh temp directory, named after ``source_path``."""
        stem = os.path.splitext(os.path.basename(source_path))[0] or "audio"
        return os.path.join(tempfile.mkdtemp(prefix="dac_"), stem + suffix)

    @staticmethod
    def _load_signal(path):
        """Load ``path`` as an AudioSignal, converting non-WAV input via pydub."""
        if path.lower().endswith(".wav"):
            return AudioSignal(path)

        print(f"Converting {path} to WAV...")
        segment = AudioSegment.from_file(path)
        wav_buffer = io.BytesIO()
        segment.export(wav_buffer, format="wav")
        wav_buffer.seek(0)
        # AudioSignal takes a path or an array, not a file-like object, so
        # decode the in-memory WAV into a (frames, channels) array first...
        samples, sample_rate = sf.read(wav_buffer, dtype="float32", always_2d=True)
        # ...and hand it over channels-first: (channels, samples).
        return AudioSignal(np.ascontiguousarray(samples.T), sample_rate=sample_rate)

    def _decompress(self, input_file):
        """Load a ``.dac`` file and decompress it.

        Returns:
            ``(path, signal)``: the resolved input path and the decompressed
            AudioSignal.
        """
        path = self._as_path(input_file)
        print("Loading compressed audio...")
        # NOTE(review): DACFile.load appears to use np.load(allow_pickle=True),
        # which makes an uploaded .dac file untrusted pickle data. Confirm this
        # and restrict uploads to trusted files before exposing the app publicly.
        compressed = dac.DACFile.load(path)
        print("Decompressing audio...")
        return path, self.model.decompress(compressed)

    @spaces.GPU
    def encode_audio(self, input_file):
        """Compress an audio file into a DAC file.

        Args:
            input_file: Path string, or an object with ``.name``, pointing at the
                audio to compress. Non-WAV input is converted with pydub.

        Returns:
            str: Path of the written ``.dac`` file (suitable for ``gr.File``).
        """
        path = self._as_path(input_file)
        signal = self._load_signal(path)
        print("Compressing audio...")
        compressed = self.model.compress(signal)
        # DACFile.save needs a real filesystem path (not a BytesIO) and returns
        # the path it wrote to.
        saved_path = compressed.save(self._temp_path(path, ".dac"))
        return str(saved_path)

    @spaces.GPU
    def decode_audio(self, input_file):
        """Decompress a DAC file to a WAV file.

        Args:
            input_file: Path string, or an object with ``.name``, of a ``.dac`` file.

        Returns:
            str: Path of the written ``.wav`` file (suitable for ``gr.Audio``).
        """
        path, decompressed = self._decompress(input_file)
        out_path = self._temp_path(path, ".wav")
        # AudioSignal.write takes a destination path (no ``format`` argument, no
        # BytesIO). Move to CPU first, since the signal may still be on the GPU.
        decompressed.cpu().write(out_path)
        return out_path

    @spaces.GPU
    def stream_audio(self, input_file):
        """Decompress a DAC file and return it as in-memory audio.

        Args:
            input_file: Path string, or an object with ``.name``, of a ``.dac`` file.

        Returns:
            tuple: ``(sample_rate, audio_data)`` as accepted by ``gr.Audio``.
        """
        _, decompressed = self._decompress(input_file)
        # Presumably audio_data is (batch, channels, samples) -- TODO confirm.
        # squeeze() drops singleton dims and .T puts samples first, as Gradio
        # expects for multi-channel audio.
        audio_data = decompressed.audio_data.cpu().detach().numpy().squeeze().T
        return (decompressed.sample_rate, audio_data)
|
|
|
# Build the codec wrapper once at import time; every Gradio handler below
# reuses this single loaded model (the DACApi default 44khz / 16kbps variant).
dac_api = DACApi(model_type="44khz", model_bitrate="16kbps")
|
|
|
def encode(audio):
    """Gradio handler for the Encode tab.

    Args:
        audio: Value from the ``gr.Audio(type="filepath")`` input.

    Returns:
        The result of ``dac_api.encode_audio`` for the ``gr.File`` output.

    Raises:
        gr.Error: If nothing was uploaded, so the UI shows a readable message
            instead of an opaque failure deeper in the pipeline.
    """
    if audio is None:
        raise gr.Error("Please upload an audio file to encode.")
    # Kept from the original: detach in case a tensor is ever passed in.
    if isinstance(audio, torch.Tensor):
        audio = audio.detach()
    return dac_api.encode_audio(audio)
|
|
|
def decode(audio):
    """Gradio handler for the Decode tab.

    Args:
        audio: Value from the ``gr.File`` input holding a ``.dac`` file.

    Returns:
        The result of ``dac_api.decode_audio`` for the ``gr.Audio`` output.

    Raises:
        gr.Error: If nothing was uploaded, so the UI shows a readable message
            instead of an opaque failure deeper in the pipeline.
    """
    if audio is None:
        raise gr.Error("Please upload a compressed .dac file to decode.")
    # Kept from the original: detach in case a tensor is ever passed in.
    if isinstance(audio, torch.Tensor):
        audio = audio.detach()
    return dac_api.decode_audio(audio)
|
|
|
def stream(audio):
    """Gradio handler for the Stream tab.

    Args:
        audio: Value from the ``gr.File`` input holding a ``.dac`` file.

    Returns:
        tuple: ``(sample_rate, audio_data)`` from ``dac_api.stream_audio``.

    Raises:
        gr.Error: If nothing was uploaded, so the UI shows a readable message
            instead of an opaque failure deeper in the pipeline.
    """
    if audio is None:
        raise gr.Error("Please upload a compressed .dac file to stream.")
    # Kept from the original: detach in case a tensor is ever passed in.
    if isinstance(audio, torch.Tensor):
        audio = audio.detach()
    sample_rate, audio_data = dac_api.stream_audio(audio)
    return sample_rate, audio_data
|
|
|
|
|
# Gradio UI: one tab per operation. Each tab lays out its input/output pair
# in a row, followed by a button wired to the matching handler.
with gr.Blocks() as demo:
    # Encode: audio file in -> downloadable .dac file out.
    with gr.Tab("Encode"):
        with gr.Row():
            input_audio = gr.Audio(label="Input Audio", type="filepath")
            output_file = gr.File(label="Compressed DAC File")
        encode_button = gr.Button("Encode")
        encode_button.click(fn=encode, inputs=[input_audio], outputs=[output_file])

    # Decode: .dac file in -> playable audio out.
    with gr.Tab("Decode"):
        with gr.Row():
            input_file = gr.File(label="Compressed DAC File")
            output_audio = gr.Audio(label="Decompressed Audio")
        decode_button = gr.Button("Decode")
        decode_button.click(fn=decode, inputs=[input_file], outputs=[output_audio])

    # Stream: .dac file in -> in-memory (sample_rate, array) audio out.
    with gr.Tab("Stream"):
        with gr.Row():
            stream_input = gr.File(label="Compressed DAC File")
            stream_output = gr.Audio(label="Streamed Audio")
        stream_button = gr.Button("Stream")
        stream_button.click(fn=stream, inputs=[stream_input], outputs=[stream_output])


if __name__ == "__main__":
    demo.launch()