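"""Gradio demo: compress audio into a compact .owie token file with SemantiCodec
and decode it back, either in one pass or as a stream.
"""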
import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import AsyncGenerator
import asyncio  # cooperative cancellation checks in the streaming decoder
import traceback  # full stack traces in error logs

# Prefer the GPU when available, falling back to the CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"GPU detection failed, falling back to CPU: {e}")
    torch_device = torch.device("cpu")

# Load the SemantiCodec model
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)

# Global variables for cancellation
cancel_encode = False
cancel_decode = False
cancel_stream = False

@spaces.GPU(duration=30)
def encode_audio(audio_file_path):
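    """Encode an audio file into a compressed .owie token file and return its path."""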
    global cancel_encode

    if audio_file_path is None:
        print("No audio file provided")
        return None

    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # torchaudio.save expects a 2-D (channels, samples) tensor
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Save to a temporary WAV file
        temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)

        # Encode the audio
        tokens = semanticodec.encode(temp_wav_file_path)

        # Convert tokens to NumPy
        tokens_numpy = tokens.detach().cpu().numpy()

        print(f"Tokens shape: {tokens_numpy.shape}")

        # Create temporary .owie file
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)
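        # .owie layout: 4-byte LE sample rate | 4-byte LE ndim | 4-byte LE size per dim | LZ4-compressed int64 token bytes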
        with open(temp_file_path, 'wb') as temp_file:
            # Write sample rate
            temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
            # Write shape information
            temp_file.write(len(tokens_numpy.shape).to_bytes(4, byteorder='little'))
            for dim in tokens_numpy.shape:
                temp_file.write(dim.to_bytes(4, byteorder='little'))
            # Compress and write the tokens data
            compressed_data = lz4.frame.compress(tokens_numpy.tobytes())
            temp_file.write(compressed_data)

        return temp_file_path

    except Exception as e:
        print(f"Encoding error: {e}")
        return None

    finally:
        cancel_encode = False
        if 'temp_wav_file_path' in locals():
            os.remove(temp_wav_file_path)

# Map encode_audio's result onto the file output and an optional error banner
def handle_encode_output(file_path):
    if file_path is None:
        return None, gr.Markdown("Encoding failed. Please ensure you've uploaded an audio file and try again.", visible=True)
    return file_path, gr.Markdown(visible=False)

@spaces.GPU(duration=30)
def decode_audio(encoded_file_path):
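    """Decode a .owie token file back to a temporary WAV file and return its path."""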
    global cancel_decode

    try:
        # Read the .owie header (sample rate, token shape) and the LZ4-compressed payload
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            ndim = int.from_bytes(temp_file.read(4), byteorder='little')
            shape = tuple(int.from_bytes(temp_file.read(4), byteorder='little') for _ in range(ndim))
            compressed_data = temp_file.read()
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            # copy() because np.frombuffer returns a read-only view, which torch.from_numpy warns about
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(shape).copy()

            # Move the tensor to the same device as the model
            tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)

        print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        # Save to a temporary WAV file (mkstemp instead of the deprecated, race-prone mktemp)
        temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
        return temp_wav_path

    except Exception as e:
        print(f"Decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        return None  # gr.Audio expects a file path; returning the error string would break it

    finally:
        cancel_decode = False

@spaces.GPU(duration=30)
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
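    """Decode a .owie file chunk by chunk, yielding (sample_rate, samples) tuples."""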
    global cancel_stream

    try:
        # Read the .owie header and payload (same layout as written by encode_audio)
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            ndim = int.from_bytes(temp_file.read(4), byteorder='little')
            shape = tuple(int.from_bytes(temp_file.read(4), byteorder='little') for _ in range(ndim))
            compressed_data = temp_file.read()
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            # copy() because np.frombuffer returns a read-only view
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).reshape(shape).copy()
            
            # Move the tensor to the same device as the model
            tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)

        print(f"Streaming tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")

        # Decode the audio in chunks
        chunk_size = sample_rate // 2  # number of tokens decoded per chunk (heuristic)
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation is requested

                # Slice along axis 1, the time axis of the 3-D token tensor
                tokens_chunk = tokens[:, i:i+chunk_size, :]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Drop the batch dim, move to CPU, and transpose to (samples, channels) for Gradio
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Allow for cancellation check

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        if 'sample_rate' in locals():  # the failure may have occurred before the header was parsed
            yield (sample_rate, np.zeros((sample_rate // 2, 1), dtype=np.float32))  # yield silence

    finally:
        cancel_stream = False

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_error_message = gr.Markdown(visible=False)

        def encode_wrapper(audio):
            if audio is None:
                return None, gr.Markdown("Please upload an audio file before encoding.", visible=True)
            return handle_encode_output(encode_audio(audio))

        encode_button.click(
            encode_wrapper,
            inputs=input_audio,
            outputs=[encoded_output, encode_error_message]
        )
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True), outputs=None)

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(lambda: globals().update(cancel_decode=True), outputs=None)

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True), outputs=None)

demo.queue().launch()  # queue() is required for generator-based streaming outputs