File size: 3,380 Bytes
f18f98b
c227215
f18f98b
 
 
c6e9cf1
 
 
f18f98b
 
 
 
 
 
 
c99df4b
 
f18f98b
c99df4b
c6e9cf1
44bab11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f18f98b
c99df4b
c6e9cf1
44bab11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f18f98b
 
 
 
 
 
 
c6e9cf1
 
f18f98b
 
 
 
c6e9cf1
44bab11
f18f98b
c6e9cf1
f18f98b
c6e9cf1
44bab11
c6e9cf1
44bab11
f18f98b
c6e9cf1
f18f98b
c6e9cf1
44bab11
c6e9cf1
f18f98b
44bab11
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import io
import os
import tempfile

import gradio as gr
import spaces
import torch
import dac
from audiotools import AudioSignal
from pydub import AudioSegment

class DACApi:
    """File-based wrapper around a Descript Audio Codec (DAC) model.

    The constructor downloads the requested weights, loads the model and moves
    it to CUDA when available (CPU otherwise).  ``encode_audio`` and
    ``decode_audio`` take and return filesystem paths; failures are reported
    through ``gr.Warning`` and signalled by a ``None`` return value.
    """

    def __init__(self, model_type="44khz", model_bitrate="16kbps"):
        """Download and load the DAC model.

        Args:
            model_type: Model variant forwarded to ``dac.utils.download``.
            model_bitrate: Bitrate variant forwarded to ``dac.utils.download``.
        """
        self.model_type = model_type
        self.model_bitrate = model_bitrate
        self.model_path = dac.utils.download(model_type, model_bitrate)
        print("Loading DAC model...")
        self.model = dac.DAC.load(self.model_path)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model.to(self.device)

    @staticmethod
    def _resolve_path(file_or_path):
        """Return the filesystem path behind a callback input value.

        Accepts a plain path (``str`` / ``os.PathLike``) or an object exposing
        the path as ``.name`` (e.g. a temporary-file wrapper).

        Raises:
            ValueError: If no value was provided.
            AttributeError: If the value is neither a path nor has ``.name``.
        """
        if file_or_path is None:
            raise ValueError("no file was provided")
        # Check path types first: pathlib.Path also has a ``.name`` attribute,
        # but it only holds the final path component.
        if isinstance(file_or_path, (str, os.PathLike)):
            return os.fspath(file_or_path)
        return os.fspath(file_or_path.name)

    @staticmethod
    def _new_temp_path(suffix):
        """Create an empty temporary file and return its path."""
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        return path

    @staticmethod
    def _discard(path):
        """Best-effort removal of a partially written output file."""
        if path is None:
            return
        try:
            os.remove(path)
        except OSError:
            pass  # Nothing useful to do if cleanup itself fails.

    @spaces.GPU
    def encode_audio(self, audio):
        """Compress an audio file into a ``.dac`` file.

        Args:
            audio: Path to the input audio, or an object exposing that path as
                ``.name``.  Any format pydub can read is accepted.

        Returns:
            Path (``str``) of the written ``.dac`` file, or ``None`` if
            encoding failed (a ``gr.Warning`` is shown in that case).
        """
        dac_path = None
        try:
            source_path = self._resolve_path(audio)

            with tempfile.TemporaryDirectory() as workdir:
                # Normalise the input to WAV on disk; AudioSignal is loaded
                # from a path rather than from an in-memory buffer.
                wav_path = os.path.join(workdir, "input.wav")
                # export() hands back the open file handle; close it so the
                # data is flushed before the file is read again.
                AudioSegment.from_file(source_path).export(
                    wav_path, format="wav"
                ).close()

                # The signal is deliberately left on the CPU, mirroring the
                # upstream README usage of compress() -- NOTE(review): confirm
                # compress() moves chunks to the model's device itself.
                signal = AudioSignal(wav_path)

                with torch.no_grad():
                    print("Compressing audio...")  # You can keep this for console logging
                    compressed = self.model.compress(signal)

            # Persist to a real file: the result is consumed as a file path.
            dac_path = self._new_temp_path(".dac")
            saved_path = compressed.save(dac_path)
            return str(saved_path or dac_path)

        except Exception as e:
            self._discard(dac_path)
            # Display error message in a popup
            gr.Warning(f"An error occurred during encoding: {e}")
            return None  # Return None to indicate failure

    @spaces.GPU
    def decode_audio(self, compressed_file):
        """Decompress a ``.dac`` file back into a WAV file.

        Args:
            compressed_file: Path to the ``.dac`` file, or an object exposing
                that path as ``.name``.

        Returns:
            Path (``str``) of the written WAV file, or ``None`` if decoding
            failed (a ``gr.Warning`` is shown in that case).
        """
        wav_path = None
        try:
            dac_path = self._resolve_path(compressed_file)

            # NOTE(review): this file is user-supplied; confirm how
            # dac.DACFile.load deserialises it -- pickle-based loaders are
            # unsafe on untrusted input.
            compressed = dac.DACFile.load(dac_path)

            with torch.no_grad():
                print("Decompressing audio...")  # You can keep this for console logging
                decompressed = self.model.decompress(compressed)

            # Make sure the samples are on the CPU before writing to disk.
            wav_path = self._new_temp_path(".wav")
            decompressed.cpu().write(wav_path)
            return wav_path

        except Exception as e:
            self._discard(wav_path)
            # Display error message in a popup
            gr.Warning(f"An error occurred during decoding: {e}")
            return None  # Return None to indicate failure

# Shared module-level instance used by the callbacks below; constructing it
# downloads and loads the DAC model when this module is imported.
dac_api = DACApi()

def encode(audio):
    """Gradio callback for the Encode button.

    Forwards ``audio`` to the shared ``dac_api`` instance and returns whatever
    ``DACApi.encode_audio`` returns (``None`` when encoding failed).
    """
    return dac_api.encode_audio(audio)

def decode(compressed_file):
    """Gradio callback for the Decode button.

    Forwards ``compressed_file`` to the shared ``dac_api`` instance and returns
    whatever ``DACApi.decode_audio`` returns (``None`` when decoding failed).
    """
    return dac_api.decode_audio(compressed_file)

# Gradio UI: one tab per direction, each wiring a button to its callback.
with gr.Blocks() as demo:
    gr.Markdown("# Audio Compression with DAC")

    # Encode tab: audio in (handed to the callback as a file path),
    # compressed file out.
    with gr.Tab("Encode"):
        audio_input = gr.Audio(type="filepath", label="Input Audio")
        encode_button = gr.Button("Encode")
        encoded_output = gr.File(label="Compressed Audio")
        encode_button.click(
            fn=encode,
            inputs=audio_input,
            outputs=encoded_output,
        )

    # Decode tab: compressed file in, playable audio out.
    with gr.Tab("Decode"):
        compressed_input = gr.File(label="Compressed Audio")
        decode_button = gr.Button("Decode")
        decoded_output = gr.Audio(label="Decompressed Audio")
        decode_button.click(
            fn=decode,
            inputs=compressed_input,
            outputs=decoded_output,
        )

# Enable the request queue, then start the app.
demo.queue()
demo.launch()