import gradio as gr
import spaces
import torch
import torchaudio
from semanticodec import SemantiCodec
import tempfile
import numpy as np
import lz4.frame
import os
from typing import Generator
import asyncio  # asyncio.sleep(0) lets the streaming loop yield so cancel requests are seen

# Prefer the GPU; fall back to CPU if unavailable
try:
    torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

# Load the SemantiCodec model
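# token_rate sets how many tokens the codec emits per second of audio and,
# together with semantic_vocab_size, determines the size/quality trade-off
# of the compressed stream.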
semanticodec = SemantiCodec(token_rate=100, semantic_vocab_size=32768).to(torch_device)

# Global cancellation flags (set by the Cancel buttons in the UI)
cancel_encode = False
cancel_decode = False
cancel_stream = False
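# Note: only the streaming loop below actually polls its flag; encode_audio
# and decode_audio run as single blocking calls, so their cancel flags are
# set but never checked mid-run.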

@spaces.GPU(duration=250)
def encode_audio(audio_file_path):
    global cancel_encode
    # Load the audio; it is re-saved as a WAV file below because the codec's encoder takes a file path
    waveform, sample_rate = torchaudio.load(audio_file_path)
    
    # Ensure waveform has the right dimensions
    if waveform.ndim == 1:
        waveform = waveform.unsqueeze(0)

    # Create a temporary WAV file
    temp_wav_fd, temp_wav_file_path = tempfile.mkstemp(suffix=".wav")
    os.close(temp_wav_fd)
    torchaudio.save(temp_wav_file_path, waveform, sample_rate)

    # Encode the audio using the WAV file path, then remove the temp WAV
    tokens = semanticodec.encode(temp_wav_file_path)
    os.remove(temp_wav_file_path)

    # Convert to NumPy and save to a temporary .owie file
    tokens_numpy = tokens.detach().cpu().numpy()
    temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
    os.close(temp_fd)
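    # .owie layout (read back by decode_audio and stream_decode_audio below):
    #   bytes 0-3  : original sample rate, little-endian uint32
    #   bytes 4-   : LZ4-compressed raw bytes of the int64 token array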
    with open(temp_file_path, 'wb') as temp_file:
        # Store the sample rate as the first 4 bytes
        temp_file.write(sample_rate.to_bytes(4, byteorder='little'))
        # Compress and write the encoded data
        compressed_data = lz4.frame.compress(tokens_numpy.tobytes())
        temp_file.write(compressed_data)

    return temp_file_path

@spaces.GPU(duration=250)
def decode_audio(encoded_file_path):
    global cancel_decode
    # Load encoded data and sample rate from the .owie file
    with open(encoded_file_path, 'rb') as temp_file:
        sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
        compressed_data = temp_file.read()
        tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
        tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)  # Ensure dtype matches encoder's output
        tokens = torch.from_numpy(tokens_numpy).to(torch_device)

    # np.frombuffer yields a flat array, so restore the (batch, length, 2)
    # shape the encoder produced (assumed SemantiCodec layout: one
    # semantic/acoustic token pair per frame; adjust if your version differs)
    if tokens.ndimension() == 1:
        tokens = tokens.reshape(1, -1, 2)
    elif tokens.ndimension() == 2:  # (length, 2) without a batch dimension
        tokens = tokens.unsqueeze(0)

    # Debugging prints to check tensor shapes
    print(f"Tokens shape: {tokens.shape}, dtype: {tokens.dtype}")

    # Decode the audio
    with torch.no_grad():
        waveform = semanticodec.decode(tokens)

    # decode may return numpy depending on the semanticodec version
    if not torch.is_tensor(waveform):
        waveform = torch.from_numpy(waveform)

    # Save to a temporary WAV file (mkstemp avoids the insecure, deprecated mktemp)
    temp_wav_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(temp_wav_fd)
    torchaudio.save(temp_wav_path, waveform.squeeze(0).cpu(), sample_rate)
    return temp_wav_path
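
# Minimal sketch of a round trip without the UI, assuming an "input.wav"
# (hypothetical path) in the working directory:
#
#     owie_path = encode_audio("input.wav")
#     wav_path = decode_audio(owie_path)
#     print(f"Restored audio written to {wav_path}")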

@spaces.GPU(duration=250)
async def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]:
    global cancel_stream
    try:
        # Load encoded data and sample rate from the .owie file
        with open(encoded_file_path, 'rb') as temp_file:
            sample_rate = int.from_bytes(temp_file.read(4), byteorder='little')
            compressed_data = temp_file.read()
            tokens_numpy_bytes = lz4.frame.decompress(compressed_data)
            tokens_numpy = np.frombuffer(tokens_numpy_bytes, dtype=np.int64)  # Ensure dtype matches encoder's output
            tokens = torch.from_numpy(tokens_numpy).to(torch_device)

        # As in decode_audio, restore the (batch, length, 2) token shape
        # (assumed SemantiCodec layout; adjust if your version differs)
        if tokens.ndimension() == 1:
            tokens = tokens.reshape(1, -1, 2)
        elif tokens.ndimension() == 2:
            tokens = tokens.unsqueeze(0)

        # Decode the audio in chunks. chunk_size counts tokens, not samples,
        # so reusing the sample rate here is an arbitrary (and fairly large)
        # chunk; tune it for smoother streaming.
        chunk_size = sample_rate
        with torch.no_grad():
            for i in range(0, tokens.shape[1], chunk_size):
                if cancel_stream:
                    break  # Exit the loop if cancellation is requested

                tokens_chunk = tokens[:, i:i+chunk_size]
                audio_chunk = semanticodec.decode(tokens_chunk)
                # Normalise to a (samples, channels) numpy array for Gradio
                if torch.is_tensor(audio_chunk):
                    audio_chunk = audio_chunk.cpu().numpy()
                audio_data = audio_chunk[0].T  # drop batch dim, transpose
                yield (sample_rate, audio_data)
                await asyncio.sleep(0)  # Allow for cancellation check

    except Exception as e:
        print(f"Streaming decoding error: {e}")
        # sample_rate/chunk_size may be unset if the header read failed, so
        # fall back to one second of silence at a nominal rate
        fallback_rate = 16000
        yield (fallback_rate, np.zeros((fallback_rate, 1), dtype=np.float32))

    finally:
        cancel_stream = False  # Reset cancel flag after streaming
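
# Note: each chunk is decoded independently, so the decoder has no context
# across chunk boundaries and audible seams between chunks are possible.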

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with SemantiCodec (GPU/CPU)")

    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")  # Using "filepath" mode
        encode_button = gr.Button("Encode")
        cancel_encode_button = gr.Button("Cancel")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")  # Using "filepath" mode

        encode_button.click(encode_audio, inputs=input_audio, outputs=encoded_output)
        cancel_encode_button.click(lambda: globals().update(cancel_encode=True),
                                   outputs=None)  # set the cancel_encode flag

    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")  # Using "filepath" mode
        decode_button = gr.Button("Decode")
        cancel_decode_button = gr.Button("Cancel")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")  # Using "filepath" mode

        decode_button.click(decode_audio, inputs=input_encoded, outputs=decoded_output)
        cancel_decode_button.click(lambda: globals().update(cancel_decode=True),
                                   outputs=None)  # set the cancel_decode flag

    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")  # Using "filepath" mode
        stream_button = gr.Button("Start Streaming")
        cancel_stream_button = gr.Button("Cancel")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)

        stream_button.click(stream_decode_audio, inputs=input_encoded_stream, outputs=audio_output)
        cancel_stream_button.click(lambda: globals().update(cancel_stream=True),
                                   outputs=None)  # set the cancel_stream flag

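# Queueing is required for the generator-based streaming output and lets
# cancel clicks be handled while a long job is running.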
demo.queue().launch()