import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import AsyncGenerator
import asyncio  # Import asyncio for cancellation
import traceback  # Import traceback for error handling

# Attempt to use the GPU, falling back to CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

# Load the SemantiCodec model
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)
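# Note: token_rate and semantic_vocab_size trade compressed size against
# reconstruction quality; roughly, a lower token rate yields smaller
# .owie files at the cost of fidelity.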

# Global variables for cancellation
cancel_encode = False
cancel_decode = False
cancel_stream = False
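# These flags are cooperative: only cancel_stream is actually polled while
# work is in progress (once per streamed chunk); cancel_encode and
# cancel_decode are currently only reset in their finally blocks.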

@spaces.GPU(duration=30)  # Request up to 30 s of ZeroGPU time per call
def encode_audio(audio_file_path):
    global cancel_encode

    if audio_file_path is None:
        print("No audio file provided")
        return None

    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Ensure waveform has the correct number of dimensions
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Save to a temporary WAV file
        temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)

        # Encode the audio
        tokens = semanticodec.encode(temp_wav_file_path)

        # Convert tokens to a NumPy array (written to the .owie container below)
        tokens_numpy = tokens.detach().cpu().numpy()

        print(f"Original tokens shape: {tokens_numpy.shape}")

        # Ensure tokens_numpy is 2D
        if tokens_numpy.ndim == 1:
            tokens_numpy = tokens_numpy.reshape(1, -1)
        elif tokens_numpy.ndim == 2:
            pass  # Already 2D
        elif tokens_numpy.ndim == 3 and tokens_numpy.shape[0] == 1:
            tokens_numpy = tokens_numpy.squeeze(0)
        else:
            raise ValueError(f"Unexpected tokens array shape: {tokens_numpy.shape}")

        print(f"Reshaped tokens shape: {tokens_numpy.shape}")

        # Create temporary .owie file
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)
        with open(temp_file_path, 'wb') as temp_file:
            # Write sample rate
            temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
            # Compress and write the tokens data
            compressed_data = lz4.frame.compress(tokens_numpy.tobytes())
            temp_file.write(compressed_data)

        return temp_file_path

    except Exception as e:
        print(f"Encoding error: {e}")
        return None  # Return None instead of the error message

    finally:
        cancel_encode = False  # Reset cancel flag after encoding
        if 'temp_wav_file_path' in locals():
            os.remove(temp_wav_file_path)  # Clean up temporary WAV file
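
# Illustrative sketch (not called by the app): how a downstream consumer
# could parse the .owie container written above. The layout is a 4-byte
# little-endian sample rate followed by an lz4 frame of int64 token values;
# the helper name is ours, not part of the SemantiCodec API.
def inspect_owie(file_path):
    with open(file_path, 'rb') as f:
        rate = int.from_bytes(f.read(4), byteorder='little')
        tokens = np.frombuffer(lz4.frame.decompress(f.read()), dtype=np.int64)
    # Tokens flatten to (frame, pair) order, matching the reshape(1, -1, 2)
    # used in decode_audio below, hence the // 2.
    return {"sample_rate": rate, "token_frames": len(tokens) // 2}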

# Map the encoder's result onto the Gradio outputs (file + error banner)
def handle_encode_output(file_path):
    if file_path is None:
        return None, gr.Markdown("Encoding failed. Please ensure you've uploaded an audio file and try again.", visible=True)
    return file_path, gr.Markdown(visible=False)

@spaces.GPU(duration=30)  # Request up to 30 s of ZeroGPU time per call
def decode_audio(encoded_file_path):
    global cancel_decode

    try:
        # Load encoded data and sample rate
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)

            # Reshape the flat buffer to (1, N, 2): one token pair per frame
            tokens_numpy = tokens_numpy.reshape(1, -1, 2)

            # np.frombuffer returns a read-only view; make a writable copy
            tokens_numpy = np.array(tokens_numpy, copy=True)

            # Move the tensor to the same device as the model
            tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)

        # Debugging prints to check tensor shapes and device
        print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        # Save to a temporary WAV file (mkstemp avoids the race in the
        # deprecated tempfile.mktemp)
        temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
        return temp_wav_path

    except Exception as e:
        print(f"Decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        return None  # Return None so the filepath Audio output stays empty on failure

    finally:
        cancel_decode = False  # Reset cancel flag after decoding

@spaces.GPU(duration=30)  # Request up to 30 s of ZeroGPU time per call
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    global cancel_stream

    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)
            tokens_numpy = tokens_numpy.reshape(1, -1, 2)

            # np.frombuffer returns a read-only view; make a writable copy
            tokens_numpy = np.array(tokens_numpy, copy=True)
            
            # Move the tensor to the same device as the model
            tokens = torch.from_numpy(tokens_numpy).to(device=semanticodec.device)

        print(f"Streaming tokens shape: {tokens.shape}, dtype: {tokens.dtype}, device: {tokens.device}")
        print(f"Model device: {semanticodec.device}")

        # Decode the audio in chunks
        chunk_size = sample_rate // 2  # Token frames per chunk; a heuristic, not an exact half-second of audio
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation is requested

                tokens_chunk = tokens[:, i:i+chunk_size, :]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Convert to numpy array and transpose
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Allow for cancellation check

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        print(f"Traceback: {traceback.format_exc()}")
        # Yield a short burst of silence; guard against failures that occur
        # before sample_rate is read from the file (16 kHz is an arbitrary fallback)
        fallback_rate = sample_rate if 'sample_rate' in locals() else 16000
        yield (fallback_rate, np.zeros((fallback_rate // 2, 1), dtype=np.float32))

    finally:
        cancel_stream = False  # Reset cancel flag after streaming
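
# Caveat: decoding fixed-size token chunks independently can leave audible
# seams at chunk boundaries, since each semanticodec.decode call sees no
# context from its neighbours; overlapping chunks plus a crossfade would be
# one mitigation.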

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")  # Using "filepath" mode
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")  # Using "filepath" mode
        encode_error_message = gr.Markdown(visible=False)

        def encode_wrapper(audio):
            if audio is None:
                return None, gr.Markdown("Please upload an audio file before encoding.", visible=True)
            return handle_encode_output(encode_audio(audio))

        encode_button.click(
            encode_wrapper,
            inputs=input_audio,
            outputs=[encoded_output, encode_error_message]
        )
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True), outputs=None)  # Set cancel_encode flag

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")  # Using "filepath" mode
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")  # Using "filepath" mode

        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(lambda: globals().update(cancel_decode=True), outputs=None)  # Set cancel_decode flag

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")  # Using "filepath" mode
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True), outputs=None)  # Set cancel_stream flag

demo.queue().launch()
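
# Running this script launches the Gradio app; on a Hugging Face ZeroGPU
# Space, the @spaces.GPU decorators above request GPU hardware per call.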