File size: 3,565 Bytes
f18f98b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import io
import os
import tempfile

import gradio as gr
import spaces
import torch
import dac
import numpy as np
from pydub import AudioSegment
from audiotools import AudioSignal
import soundfile as sf

class DACApi:
    """Wrapper around a pretrained DAC model for the Gradio callbacks.

    The model weights are downloaded and loaded once in ``__init__``. The
    public methods take whatever Gradio hands to a callback for an uploaded
    file (a path string, or an object exposing the path as ``.name``) and
    return values that Gradio output components can display: a file path
    for ``encode_audio`` / ``decode_audio`` and a ``(sample_rate, samples)``
    tuple for ``stream_audio``.
    """

    def __init__(self, model_type="44khz", model_bitrate="16kbps"):
        """Download and load the requested DAC model.

        Args:
            model_type: Model variant passed to ``dac.utils.download``.
            model_bitrate: Bitrate variant passed to ``dac.utils.download``.
        """
        self.model_type = model_type
        self.model_bitrate = model_bitrate
        self.model_path = dac.utils.download(model_type, model_bitrate)
        print("Loading DAC model...")
        self.model = dac.DAC.load(self.model_path)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)

    @staticmethod
    def _as_path(input_file):
        """Return the filesystem path behind a Gradio upload value.

        ``gr.Audio(type="filepath")`` passes a plain string, while file
        uploads may arrive as an object carrying the path in ``.name``;
        both forms are accepted here.

        Raises:
            ValueError: If no file was supplied (``input_file`` is None).
        """
        if input_file is None:
            raise ValueError("No input file was provided.")
        if isinstance(input_file, (str, os.PathLike)):
            return os.fspath(input_file)
        return input_file.name

    @staticmethod
    def _temp_path(suffix):
        """Create an empty temporary file and return its path."""
        fd, path = tempfile.mkstemp(suffix=suffix)
        # Only the path is needed; the libraries reopen the file themselves.
        os.close(fd)
        return path

    def _decompress(self, input_file):
        """Load a ``.dac`` file and return the model's decompressed signal."""
        print("Loading compressed audio...")
        compressed = dac.DACFile.load(self._as_path(input_file))

        print("Decompressing audio...")
        return self.model.decompress(compressed)

    @spaces.GPU
    def encode_audio(self, input_file):
        """Compress an audio file and return the path of the ``.dac`` file.

        MP3 input is first converted to a temporary WAV file; any other
        input is handed to ``AudioSignal`` as-is.

        Args:
            input_file: Path string, or object whose ``.name`` is the path.

        Returns:
            Path (str) of a temporary ``.dac`` file. The file is left on
            disk so the caller (e.g. a ``gr.File`` output) can serve it.

        Raises:
            ValueError: If ``input_file`` is None.
        """
        source_path = self._as_path(input_file)

        # Convert MP3 to WAV if necessary
        if source_path.lower().endswith('.mp3'):
            print("Converting MP3 to WAV...")
            wav_path = self._temp_path(".wav")
            try:
                # export() hands back the handle it opened; close it so the
                # temporary file can be removed afterwards.
                AudioSegment.from_mp3(source_path).export(
                    wav_path, format="wav"
                ).close()
                # The signal is read while the temporary file still exists.
                signal = AudioSignal(wav_path)
            finally:
                os.remove(wav_path)
        else:
            signal = AudioSignal(source_path)

        # Compress audio
        print("Compressing audio...")
        compressed = self.model.compress(signal)

        # DACFile.save needs a real path; the ".dac" suffix is already set
        # so the file lands exactly at output_path.
        output_path = self._temp_path(".dac")
        compressed.save(output_path)
        return output_path

    @spaces.GPU
    def decode_audio(self, input_file):
        """Decompress a ``.dac`` file and return the path of a WAV file.

        Args:
            input_file: Path string, or object whose ``.name`` is the path.

        Returns:
            Path (str) of a temporary ``.wav`` file holding the decoded
            audio, left on disk for the caller to serve.

        Raises:
            ValueError: If ``input_file`` is None.
        """
        decompressed = self._decompress(input_file)

        output_path = self._temp_path(".wav")
        # First batch item, moved to CPU and transposed so that samples run
        # along the first axis as soundfile expects.
        samples = decompressed.audio_data[0].cpu().numpy().T
        sf.write(output_path, samples, decompressed.sample_rate)
        return output_path

    @spaces.GPU
    def stream_audio(self, input_file):
        """Decompress a ``.dac`` file and return the audio in memory.

        Args:
            input_file: Path string, or object whose ``.name`` is the path.

        Returns:
            Tuple ``(sample_rate, audio_data)`` where ``audio_data`` is the
            decoded signal as a squeezed, transposed NumPy array.

        Raises:
            ValueError: If ``input_file`` is None.
        """
        decompressed = self._decompress(input_file)

        audio_data = decompressed.audio_data.cpu().numpy().squeeze().T
        sample_rate = decompressed.sample_rate

        return (sample_rate, audio_data)

# Single module-level DACApi used by the encode/decode/stream callbacks.
# Constructing it downloads and loads the default ("44khz", "16kbps") model,
# so this runs as soon as the module is imported.
dac_api = DACApi()

def encode(audio):
    """Gradio callback: compress ``audio`` with the shared ``dac_api``.

    Returns whatever ``dac_api.encode_audio`` produces for the input.
    """
    return dac_api.encode_audio(audio)

def decode(audio):
    """Gradio callback: decompress ``audio`` with the shared ``dac_api``.

    Returns whatever ``dac_api.decode_audio`` produces for the input.
    """
    return dac_api.decode_audio(audio)

def stream(audio):
    """Gradio callback: decode ``audio`` to in-memory samples.

    Delegates to ``dac_api.stream_audio`` and returns its
    ``(sample_rate, audio_data)`` pair.
    """
    rate, samples = dac_api.stream_audio(audio)
    return rate, samples

# Gradio interface: three tabs (Encode, Decode, Stream), each with an input
# component, an output component and a button wired to the matching callback.
with gr.Blocks() as demo:
    
    # Encode: audio in (delivered to the callback as a file path), DAC file out.
    with gr.Tab("Encode"):
        with gr.Row():
            input_audio = gr.Audio(type="filepath", label="Input Audio")
            output_file = gr.File(label="Compressed DAC File")
        encode_button = gr.Button("Encode")
        encode_button.click(encode, inputs=[input_audio], outputs=[output_file])
    
    # Decode: compressed DAC file in, decompressed audio out.
    with gr.Tab("Decode"):
        with gr.Row():
            input_file = gr.File(label="Compressed DAC File")
            output_audio = gr.Audio(label="Decompressed Audio")
        decode_button = gr.Button("Decode")
        decode_button.click(decode, inputs=[input_file], outputs=[output_audio])
    
    # Stream: compressed DAC file in; the callback returns
    # (sample_rate, audio_data) for the audio component.
    with gr.Tab("Stream"):
        with gr.Row():
            stream_input = gr.File(label="Compressed DAC File")
            stream_output = gr.Audio(label="Streamed Audio")
        stream_button = gr.Button("Stream")
        stream_button.click(stream, inputs=[stream_input], outputs=[stream_output])

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()