File size: 8,597 Bytes
f18f98b
6fa15d7
eb0f782
 
763a29b
a2cc897
bd40662
eb0f782
 
 
763a29b
c707064
bd40662
555a678
9f23276
a7dbbfe
50a007c
763a29b
 
 
c9a89ac
cefd33c
eb0f782
763a29b
841c4fd
306e4c8
 
 
 
841c4fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cefd33c
841c4fd
 
cefd33c
841c4fd
 
 
 
 
 
 
fea8244
 
 
 
841c4fd
 
b8668a7
841c4fd
 
 
 
 
 
cefd33c
841c4fd
 
cefd33c
73fdeaf
cefd33c
73fdeaf
 
 
 
306e4c8
73fdeaf
763a29b
cefd33c
eb0f782
763a29b
4ca3581
841c4fd
 
 
 
b8668a7
 
 
 
841c4fd
b8668a7
841c4fd
5920386
555a678
5920386
 
 
 
 
 
 
555a678
9e51fbe
841c4fd
 
 
4ca3581
555a678
 
7634a29
 
 
841c4fd
 
7634a29
841c4fd
4ca3581
841c4fd
 
c707064
cefd33c
841c4fd
 
cefd33c
763a29b
cefd33c
763a29b
 
841c4fd
eb0f782
d24bcbd
eb0f782
d24bcbd
b8668a7
 
 
 
763a29b
b8668a7
ff92e4e
5920386
555a678
5920386
 
 
 
 
 
 
555a678
9f23276
eb0f782
7634a29
eb0f782
b8668a7
763a29b
 
 
b8668a7
763a29b
d888fa7
 
 
763a29b
c27eb74
 
eb0f782
c707064
763a29b
c27eb74
763a29b
cefd33c
d24bcbd
eb0f782
d07be48
763a29b
44bab11
f18f98b
cefd33c
eb0f782
763a29b
cefd33c
73fdeaf
44bab11
306e4c8
 
 
 
 
73fdeaf
306e4c8
73fdeaf
 
 
cefd33c
44bab11
f18f98b
cefd33c
eb0f782
763a29b
cefd33c
eb0f782
 
cefd33c
eb0f782
 
cefd33c
eb0f782
763a29b
d888fa7
44bab11
eb0f782
cefd33c
f18f98b
73fdeaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import asyncio  # Import asyncio for cancellation
import os
import tempfile
import traceback  # Import traceback for error handling
from typing import AsyncGenerator, Generator

import gradio as gr
import lz4.frame
import numpy as np
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec

# Build the shared SemantiCodec instance; no device argument is passed here
semanticodec = SemantiCodec(
    token_rate=100,
    semantic_vocab_size=32768,
)

# Module-wide cancellation flags, one per operation (encode / decode / stream)
cancel_encode = cancel_decode = cancel_stream = False

@spaces.GPU(duration=30)
def encode_audio(audio_file_path):
    """Encode an audio file into a compressed ``.owie`` token file.

    The audio is loaded with torchaudio, re-saved as a temporary WAV file,
    and passed to ``semanticodec.encode``. The resulting tokens are written
    to a new temporary file with this layout (all integers are 4-byte
    little-endian):

        sample rate | number of dims | each dim | compressed size | LZ4 frame

    The LZ4 frame holds the raw bytes of the token array as ``int64``, which
    is the dtype both decoders in this file use to read it back.

    Args:
        audio_file_path: Path of the audio file to encode, or ``None``.

    Returns:
        Path of the written ``.owie`` file, or ``None`` when no input was
        given or any step failed (the error is printed).
    """
    global cancel_encode

    if audio_file_path is None:
        print("No audio file provided")
        return None

    # Track temp paths explicitly so cleanup never depends on locals().
    temp_wav_file_path = None
    temp_file_path = None

    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Ensure waveform has a leading channel dimension
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Save to a temporary WAV file, since the encoder is given a path
        temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)

        # Encode the audio
        tokens = semanticodec.encode(temp_wav_file_path)

        # The decoders read the payload with dtype=np.int64, so pin the dtype
        # (and a contiguous layout) here instead of trusting the model output.
        tokens_numpy = np.ascontiguousarray(
            tokens.detach().cpu().numpy(), dtype=np.int64
        )

        print(f"Tokens shape: {tokens_numpy.shape}")

        compressed_data = lz4.frame.compress(tokens_numpy.tobytes())

        # Assemble the header first so a value that does not fit in 4 bytes
        # fails before anything is written to disk.
        header = [
            int(sample_rate).to_bytes(4, byteorder='little'),
            len(tokens_numpy.shape).to_bytes(4, byteorder='little'),
        ]
        header.extend(
            int(dim).to_bytes(4, byteorder='little') for dim in tokens_numpy.shape
        )
        header.append(len(compressed_data).to_bytes(4, byteorder='little'))

        # Create the temporary .owie file and write header + payload
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        with os.fdopen(temp_fd, 'wb') as temp_file:
            temp_file.write(b"".join(header))
            temp_file.write(compressed_data)

        return temp_file_path

    except Exception as e:
        print(f"Encoding error: {e}")
        # Do not leave a partially written .owie file behind.
        if temp_file_path is not None:
            try:
                os.remove(temp_file_path)
            except OSError:
                pass
        return None

    finally:
        # The flag is only reset here; this function never reads it.
        cancel_encode = False
        if temp_wav_file_path is not None:
            try:
                os.remove(temp_wav_file_path)
            except OSError:
                pass

# Map the encoder result onto the two outputs of the Encode tab
def handle_encode_output(file_path):
    """Convert an encode result into ``(file value, message component)``.

    Args:
        file_path: Path returned by the encoder, or ``None`` on failure.

    Returns:
        A 2-tuple: the path (or ``None``) and a ``gr.Markdown`` that is
        hidden on success and shows a failure notice otherwise.
    """
    if file_path is not None:
        return file_path, gr.Markdown(visible=False)

    failure_notice = gr.Markdown(
        "Encoding failed. Please ensure you've uploaded an audio file and try again.",
        visible=True,
    )
    return None, failure_notice

@spaces.GPU(duration=30)
def decode_audio(encoded_file_path):
    """Decode an ``.owie`` file back into a temporary WAV file.

    The file is read as: 4-byte little-endian sample rate, 4-byte number of
    dims, one 4-byte value per dim, 4-byte compressed size, then an LZ4
    frame holding the ``int64`` token array.

    Args:
        encoded_file_path: Path of the ``.owie`` file to decode.

    Returns:
        Path of the decoded WAV file, or ``None`` if anything failed (the
        error and traceback are printed). ``None`` is returned rather than
        the error text because the result is wired to a filepath-typed
        audio output, where a message string would be treated as a path.
    """
    global cancel_decode

    temp_wav_path = None

    try:
        # Load encoded data and sample rate
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            ndim = int.from_bytes(temp_file.read(4), byteorder='little')
            shape = tuple(
                int.from_bytes(temp_file.read(4), byteorder='little')
                for _ in range(ndim)
            )
            compressed_size = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read(compressed_size)

        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(shape)

        # np.frombuffer over bytes yields a read-only array; copy it so the
        # tensor is backed by writable memory.
        tokens = torch.from_numpy(tokens_numpy.copy())

        # Determine the device of the model
        model_device = next(semanticodec.parameters()).device
        print(f"Model device: {model_device}")

        # Move the tokens to the same device as the model
        tokens = tokens.to(model_device)
        print(f"Tokens device: {tokens.device}")

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        print(f"Waveform device: {waveform.device}")

        # Move waveform to CPU for saving
        waveform_cpu = waveform.cpu()

        # mkstemp creates the file atomically; mktemp only returns a name
        # and is deprecated because of the race that leaves open.
        temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_path, waveform_cpu.squeeze(0), sample_rate)
        return temp_wav_path

    except Exception as e:
        print(f"Decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        # Remove the output file if it was created but could not be filled.
        if temp_wav_path is not None:
            try:
                os.remove(temp_wav_path)
            except OSError:
                pass
        return None

    finally:
        # The flag is only reset here; this function never reads it.
        cancel_decode = False

@spaces.GPU(duration=30)
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    """Decode an ``.owie`` file chunk by chunk, yielding audio as it is ready.

    The file is read as: 4-byte little-endian sample rate, 4-byte number of
    dims, one 4-byte value per dim, 4-byte compressed size, then an LZ4
    frame holding the ``int64`` token array. The tokens are sliced along
    axis 1 and each slice is decoded separately.

    Args:
        encoded_file_path: Path of the ``.owie`` file to stream.

    Yields:
        ``(sample_rate, audio_data)`` tuples, where ``audio_data`` is the
        decoded chunk as a transposed NumPy array. If an error occurs after
        the sample rate has been read, one chunk of silence is yielded
        instead; if it occurs earlier, the stream simply ends.
    """
    global cancel_stream

    # Defined up front so the except handler can tell whether the header was
    # read; otherwise an early failure would raise UnboundLocalError there.
    sample_rate = None

    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            ndim = int.from_bytes(temp_file.read(4), byteorder='little')
            shape = tuple(
                int.from_bytes(temp_file.read(4), byteorder='little')
                for _ in range(ndim)
            )
            compressed_size = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read(compressed_size)

        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(shape)

        # np.frombuffer over bytes yields a read-only array; copy it so the
        # tensor is backed by writable memory.
        tokens = torch.from_numpy(tokens_numpy.copy())

        # Determine the device of the model
        model_device = next(semanticodec.parameters()).device
        print(f"Model device: {model_device}")

        # Move the tokens to the same device as the model
        tokens = tokens.to(model_device)
        print(f"Streaming tokens device: {tokens.device}")

        # Decode the audio in chunks
        # NOTE(review): chunk_size is derived from the audio sample rate but
        # is used to step over the token axis - confirm the intended unit.
        chunk_size = sample_rate * 2  # Adjust chunk size as needed
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation is requested

                tokens_chunk = tokens[:, i:i+chunk_size, :]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Convert to numpy array and transpose
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Allow for cancellation check

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        # Silence can only be emitted once a usable sample rate is known.
        if sample_rate:
            yield (sample_rate, np.zeros((sample_rate * 2, 1), dtype=np.float32))  # Return silence

    finally:
        cancel_stream = False

# Gradio Interface
def encode_wrapper(audio):
    """Click handler for the Encode button.

    Returns the pair of values for the encoded-file output and the error
    message component: a prompt to upload when ``audio`` is ``None``,
    otherwise whatever ``handle_encode_output`` makes of the encode result.
    """
    if audio is not None:
        return handle_encode_output(encode_audio(audio))
    return None, gr.Markdown("Please upload an audio file before encoding.", visible=True)


def _cancel_requester(flag_name):
    """Build a no-argument callback that sets the named module flag to True."""
    def _request_cancel():
        globals()[flag_name] = True

    return _request_cancel


with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    # Tab 1: audio file in, .owie file out
    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_error_message = gr.Markdown(visible=False)

        encode_button.click(
            fn=encode_wrapper,
            inputs=input_audio,
            outputs=[encoded_output, encode_error_message],
        )
        cancel_encode_button.click(
            fn=_cancel_requester("cancel_encode"),
            outputs=None,
        )

    # Tab 2: .owie file in, decoded audio file out
    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_button.click(
            fn=decode_audio,
            inputs=input_encoded,
            outputs=decoded_output,
        )
        cancel_decode_button.click(
            fn=_cancel_requester("cancel_decode"),
            outputs=None,
        )

    # Tab 3: .owie file in, audio streamed out chunk by chunk
    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(
            fn=stream_decode_audio,
            inputs=input_encoded_stream,
            outputs=audio_output,
        )
        cancel_stream_button.click(
            fn=_cancel_requester("cancel_stream"),
            outputs=None,
        )

demo.queue().launch()