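"""Gradio demo: audio compression with SemantiCodec.

Encodes audio into LZ4-compressed SemantiCodec tokens stored in a small
".owie" container, and decodes them back to WAV, either in one shot or as
a stream of chunks.
"""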
import gradio as gr
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import AsyncGenerator  # stream_decode_audio is an async generator
import spaces
import asyncio  # Import asyncio for cancellation

# Attempt to use GPU, fallback to CPU
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

# Load the SemantiCodec model
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)

# Module-level cancellation flags, set by the "Cancel" buttons in the UI
cancel_encode = False
cancel_decode = False
cancel_stream = False
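# Note: under ZeroGPU, @spaces.GPU-decorated handlers may run in a separate
# process; if so, flipping these globals from the UI process might not reach a
# job already in flight. Treat this cancellation scheme as best-effort.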

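# .owie container layout (ad hoc format defined by this app, not a standard):
#   bytes 0..3 : sample rate, little-endian unsigned int
#   bytes 4..  : LZ4 frame wrapping the token array's raw int64 bytes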
@spaces.GPU(duration=500)  # Increased GPU duration to 500 seconds
def encode_audio(audio_file_path):
    global cancel_encode
    try:
        # Load the audio file
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Encode the audio
        audio = waveform.unsqueeze(0).to(torch_device)
        with torch.no_grad():
            tokens = semanticodec.encode(audio)

        # Convert to NumPy and save to a temporary .owie file
        tokens_numpy = tokens.detach().cpu().numpy()
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        os.close(temp_fd)
        with open(temp_file_path, 'wb') as temp_file:
            # Store the sample rate as the first 4 bytes
            temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
            # Compress and write the encoded data
            compressed_data = lz4.frame.compress(tokens_numpy.tobytes())
            temp_file.write(compressed_data)

        return temp_file_path

    except Exception as e:
        # Raise a Gradio error so the UI reports the failure instead of
        # treating the message string as a file path
        raise gr.Error(f"Encoding error: {e}")

    finally:
        cancel_encode = False  # Reset cancel flag after encoding

@spaces.GPU(duration=500)  # Increased GPU duration to 500 seconds
def decode_audio(encoded_file_path):
    global cancel_decode
    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            # np.frombuffer returns a read-only 1-D view, so copy it before
            # handing it to torch, and restore a leading batch dimension
            # (assumes encode produced a flat token sequence)
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).copy()
            tokens = torch.from_numpy(tokens_numpy).unsqueeze(0).to(torch_device)

        # Decode the audio
        with torch.no_grad():
            waveform = semanticodec.decode(tokens)

        # Save to a temporary WAV file (mkstemp avoids the race condition in
        # the deprecated tempfile.mktemp)
        temp_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_fd)
        torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
        return temp_wav_path

    except Exception as e:
        # Surface the failure in the UI rather than returning an error string
        # where a file path is expected
        raise gr.Error(f"Decoding error: {e}")

    finally:
        cancel_decode = False  # Reset cancel flag after decoding

@spaces.GPU(duration=500)  # Increased GPU duration to 500 seconds
async def stream_decode_audio(encoded_file_path) -> AsyncGenerator[tuple, None]:
    global cancel_stream
    sample_rate, chunk_size = 16000, 16000  # fallbacks for the error path below (16 kHz assumed)
    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            # Same copy/reshape as in decode_audio: frombuffer is read-only
            # and 1-D, while the chunking below expects a batch dimension
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64).copy()
            tokens = torch.from_numpy(tokens_numpy).unsqueeze(0).to(torch_device)

        # Decode the audio in chunks
        chunk_size = sample_rate  # tokens per chunk; reusing the sample rate is an arbitrary heuristic, not samples-per-second
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation is requested

                tokens_chunk = tokens[:, i:i+chunk_size]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Convert to numpy array and transpose
                audio_data = audio_chunk.squeeze(0).cpu().numpy().T
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Allow for cancellation check

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        yield (sample_rate, np.zeros((chunk_size, 1), dtype=np.float32))  # Return silence

    finally:
        cancel_stream = False  # Reset cancel flag after streaming

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")

        encode_button.click(encode_audio, inputs=input_audio, outputs=encoded_output)
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True),
                                     outputs=None)  # Set cancel_encode flag

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")

        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(lambda: globals().update(cancel_decode=True),
                                     outputs=None)  # Set cancel_decode flag

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True),
                                     outputs=None)  # Set cancel_stream flag

demo.queue().launch()
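
# Manual round-trip check without the UI (hypothetical paths; comment out the
# launch() call above before running this):
#   owie_path = encode_audio("example.wav")
#   wav_path = decode_audio(owie_path)
#   print(owie_path, "->", wav_path)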