File size: 7,815 Bytes
f18f98b
6fa15d7
eb0f782
 
763a29b
a2cc897
bd40662
eb0f782
 
 
763a29b
bd40662
eb0f782
 
 
 
 
 
 
c27eb74
763a29b
 
a7dbbfe
50a007c
763a29b
 
 
c9a89ac
6eabaea
eb0f782
763a29b
841c4fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73fdeaf
 
 
 
 
841c4fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73fdeaf
841c4fd
 
 
73fdeaf
 
 
 
 
 
 
 
763a29b
6eabaea
eb0f782
763a29b
4ca3581
841c4fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
086a0ea
841c4fd
 
 
 
d87908f
841c4fd
 
 
4ca3581
841c4fd
 
 
 
4ca3581
841c4fd
 
50a007c
841c4fd
 
 
763a29b
6eabaea
763a29b
 
841c4fd
eb0f782
d24bcbd
eb0f782
d24bcbd
eb0f782
763a29b
841c4fd
763a29b
eb0f782
 
d24bcbd
eb0f782
763a29b
 
 
 
 
 
d888fa7
 
 
763a29b
c27eb74
 
eb0f782
763a29b
c27eb74
763a29b
 
d24bcbd
eb0f782
d07be48
763a29b
44bab11
f18f98b
6eabaea
eb0f782
763a29b
6eabaea
73fdeaf
44bab11
73fdeaf
 
 
 
 
841c4fd
44bab11
f18f98b
6eabaea
eb0f782
763a29b
6eabaea
eb0f782
 
841c4fd
eb0f782
 
6eabaea
eb0f782
763a29b
d888fa7
44bab11
eb0f782
841c4fd
f18f98b
73fdeaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import asyncio  # Import asyncio for cancellation
import os
import tempfile
from typing import AsyncGenerator, Generator

import gradio as gr
import lz4.frame
import numpy as np
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec

# Choose the compute device: CUDA if torch reports it available, else CPU.
# If probing CUDA raises for any reason, report it and fall back to CPU.
try:
    torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

# Build the SemantiCodec model once, at import time, on the chosen device.
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)

# Cancellation flags: the UI's Cancel buttons set them to True, and each worker
# function resets its own flag in a `finally` clause when it finishes.
cancel_encode = cancel_decode = cancel_stream = False

def _remove_temp_file(path):
    """Best-effort delete of a temporary file; failures are printed, not raised."""
    if path is None:
        return
    try:
        os.remove(path)
    except OSError as error:
        print(f"Could not remove temporary file {path}: {error}")


@spaces.GPU(duration=250)
def encode_audio(audio_file_path):
    """Encode an audio file into a compressed ``.owie`` token file.

    The ``.owie`` layout written here is a 4-byte little-endian unsigned
    integer holding the *input* file's sample rate, followed by an LZ4 frame
    containing the raw bytes of the token array as int64.

    Args:
        audio_file_path: Path to the input audio file (Gradio "filepath" mode).

    Returns:
        Path to the newly created ``.owie`` temporary file, or ``None`` if
        encoding failed (the error is printed to stdout).
    """
    global cancel_encode

    temp_wav_file_path = None  # intermediate WAV handed to SemantiCodec
    temp_file_path = None      # resulting .owie file
    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Ensure the waveform has a leading channel dimension
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Re-save the input as a temporary WAV file; SemantiCodec.encode is
        # given a file path rather than a tensor.
        temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_file_path, waveform, sample_rate)

        # Encode the audio
        tokens = semanticodec.encode(temp_wav_file_path)

        # Both decoders read the payload back with np.frombuffer(..., dtype=np.int64),
        # so pin the on-disk dtype to int64 whatever dtype the codec returns.
        tokens_numpy = tokens.detach().cpu().numpy().astype(np.int64, copy=False)

        # Ensure tokens_numpy is 2D
        if tokens_numpy.ndim == 1:
            tokens_numpy = tokens_numpy.reshape(1, -1)
        elif tokens_numpy.ndim > 2:
            # NOTE(review): the .owie format stores no shape and the decoders
            # flatten the payload to [1, N]; if SemantiCodec returns >2-D tokens,
            # every encode fails here — confirm the codec's token shape.
            raise ValueError("Tokens array must be 1D or 2D")

        # Header (input sample rate) followed by the LZ4-compressed token bytes
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        with os.fdopen(temp_fd, 'wb') as temp_file:
            temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
            temp_file.write(lz4.frame.compress(tokens_numpy.tobytes()))

        return temp_file_path

    except Exception as e:
        print(f"Encoding error: {e}")
        # Don't leave a partially written .owie file behind.
        _remove_temp_file(temp_file_path)
        return None  # Return None instead of the error message

    finally:
        # NOTE(review): cancel_encode is only reset here, never checked, so the
        # flag by itself cannot interrupt an encode in progress.
        cancel_encode = False  # Reset cancel flag after encoding
        _remove_temp_file(temp_wav_file_path)  # Clean up temporary WAV file

def handle_encode_output(file_path):
    """Map ``encode_audio``'s result onto the Encode tab's two outputs.

    Args:
        file_path: Path returned by ``encode_audio``, or ``None`` on failure.

    Returns:
        A ``(file, message)`` pair: on success the path plus a hidden Markdown;
        on failure ``None`` plus a visible Markdown explaining the failure.
    """
    if file_path is not None:
        return file_path, gr.Markdown(visible=False)
    failure_notice = "Encoding failed. Please check the input file and try again."
    return None, gr.Markdown(failure_notice, visible=True)

@spaces.GPU(duration=250)
def decode_audio(encoded_file_path):
    """Decode a ``.owie`` token file back into a temporary WAV file.

    Args:
        encoded_file_path: Path to a ``.owie`` file as written by ``encode_audio``:
            a 4-byte little-endian sample-rate header followed by an LZ4 frame of
            int64 token bytes.

    Returns:
        Path to a temporary WAV file holding the reconstructed audio.

    Raises:
        gr.Error: If the file cannot be read or decoded; Gradio displays the
            message to the user.
    """
    global cancel_decode

    # SemantiCodec reconstructs audio at a fixed 16 kHz (its README saves the
    # decoder output at 16000 Hz). The .owie header stores the *input* file's
    # rate, so writing with it would change speed/pitch for non-16 kHz inputs.
    output_sample_rate = 16000

    temp_wav_path = None
    try:
        with open(encoded_file_path, 'rb') as encoded_file:
            header = encoded_file.read(4)
            compressed_data = encoded_file.read()
        if len(header) != 4:
            raise ValueError("Encoded file is too short to contain a header")

        tokens_numpy = np.frombuffer(lz4.frame.decompress(compressed_data), dtype=np.int64)
        if tokens_numpy.size == 0:
            raise ValueError("Encoded file contains no tokens")

        # np.frombuffer always returns a read-only 1-D array: restore the
        # [1, token_length] layout and copy so torch gets writable memory.
        tokens = torch.from_numpy(tokens_numpy.reshape(1, -1).copy()).to(torch_device)

        # Debugging print to check the tensor shape
        print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        # torch.as_tensor accepts a tensor or a NumPy array, whichever the codec
        # returns. Dropping dim 0 assumes a (1, channels, time) output, as the
        # original squeeze(0) did — TODO confirm; torchaudio.save wants (channels, time).
        audio = torch.as_tensor(waveform).squeeze(0).float().cpu()

        # mkstemp creates the file atomically (tempfile.mktemp is deprecated and racy).
        temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_wav_fd)
        torchaudio.save(temp_wav_path, audio, output_sample_rate)
        return temp_wav_path

    except Exception as e:
        print(f"Decoding error: {e}")
        if temp_wav_path is not None:
            try:
                os.remove(temp_wav_path)
            except OSError:
                pass  # best-effort cleanup; the decode error below matters more
        # Returning the message would hand it to the Audio output as if it were
        # a file path; gr.Error shows it to the user instead.
        raise gr.Error(f"Decoding failed: {e}") from e

    finally:
        cancel_decode = False  # Reset cancel flag after decoding

@spaces.GPU(duration=250)
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    """Decode a ``.owie`` file chunk by chunk for Gradio's streaming Audio output.

    Args:
        encoded_file_path: Path to a ``.owie`` file as written by ``encode_audio``
            (4-byte little-endian sample-rate header, then an LZ4 frame of int64
            token bytes).

    Yields:
        ``(sample_rate, samples)`` tuples where ``samples`` is a NumPy array with
        time on the first axis. On error, one second of silence is yielded.
    """
    global cancel_stream

    # SemantiCodec reconstructs audio at a fixed 16 kHz (its README saves the
    # decoder output at 16000 Hz). The .owie header stores the *input* file's
    # rate, which would play back at the wrong speed for non-16 kHz inputs.
    output_sample_rate = 16000

    # Drop any stale request left by a Cancel click made while no stream was
    # running; otherwise this stream would stop before its first chunk.
    cancel_stream = False

    def _decode_chunk(tokens_chunk):
        # Grad mode is thread-local, so no_grad must be entered in the worker thread.
        with torch.no_grad():
            return semanticodec.decode(tokens_chunk)

    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, 'rb') as encoded_file:
            stored_sample_rate = int.from_bytes(encoded_file.read(4), byteorder='little')
            compressed_data = encoded_file.read()
        tokens_numpy = np.frombuffer(lz4.frame.decompress(compressed_data), dtype=np.int64)
        # np.frombuffer returns a read-only 1-D array, but the loop slices
        # tokens[:, ...]: restore the [1, token_length] layout (and copy).
        tokens = torch.from_numpy(tokens_numpy.reshape(1, -1).copy()).to(torch_device)

        # NOTE(review): the chunk length is counted in *tokens* but taken from the
        # stored audio sample rate (e.g. 44100 tokens per chunk). Kept as in the
        # original; confirm the intended chunk size.
        chunk_size = stored_sample_rate

        loop = asyncio.get_running_loop()
        for start in range(0, tokens.shape[1], chunk_size):
            if cancel_stream:
                break  # Exit the loop if cancellation is requested

            # Run the blocking decode off the event loop so Gradio (including the
            # Cancel button's handler) stays responsive while a chunk decodes.
            audio_chunk = await loop.run_in_executor(
                None, _decode_chunk, tokens[:, start:start + chunk_size]
            )
            # Drop the leading batch dim and transpose to (time, channels);
            # assumes a (1, channels, time) codec output — TODO confirm.
            audio_data = torch.as_tensor(audio_chunk).squeeze(0).float().cpu().numpy().T
            yield (output_sample_rate, audio_data)

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        # One second of silence; output_sample_rate is always bound here, even
        # when the failure happened before the file was read.
        yield (output_sample_rate, np.zeros((output_sample_rate, 1), dtype=np.float32))

    finally:
        cancel_stream = False  # Reset cancel flag after streaming

# Gradio Interface
def _request_cancel(flag_name):
    """Return a click handler that sets the module-level flag *flag_name* to True."""
    def _set_flag():
        globals()[flag_name] = True
    return _set_flag


with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    # Each Cancel button sets its module-level flag AND passes `cancels=[...]`
    # so Gradio itself drops the event: per Gradio's docs, queued functions and
    # iterating generators are cancelled, while a plain function that is already
    # running is allowed to finish.

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")  # Using "filepath" mode
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")  # Using "filepath" mode
        encode_error_message = gr.Markdown(visible=False)

        encode_event = encode_button.click(
            lambda x: handle_encode_output(encode_audio(x)),
            inputs=input_audio,
            outputs=[encoded_output, encode_error_message]
        )
        cancel_encode_button.click(
            _request_cancel("cancel_encode"), outputs=None, cancels=[encode_event]
        )

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")  # Using "filepath" mode
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")  # Using "filepath" mode

        decode_event = decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(
            _request_cancel("cancel_decode"), outputs=None, cancels=[decode_event]
        )

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")  # Using "filepath" mode
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_event = stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(
            _request_cancel("cancel_stream"), outputs=None, cancels=[stream_event]
        )

demo.queue().launch()