File size: 4,466 Bytes
f18f98b
eb0f782
 
 
a2cc897
bd40662
eb0f782
 
 
 
bd40662
eb0f782
 
 
 
 
 
 
c27eb74
eb0f782
 
 
a7dbbfe
eb0f782
c9a89ac
eb0f782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c27eb74
 
eb0f782
c27eb74
eb0f782
 
c27eb74
eb0f782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c27eb74
eb0f782
 
a7dbbfe
eb0f782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c27eb74
 
eb0f782
 
c27eb74
eb0f782
d07be48
eb0f782
44bab11
f18f98b
eb0f782
 
 
44bab11
eb0f782
44bab11
f18f98b
eb0f782
 
 
 
 
 
 
 
 
 
44bab11
eb0f782
f18f98b
c9a89ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import gradio as gr
import torch
import torchaudio
from agc import AGC
import tempfile
import numpy as np
import lz4.frame
import os
from typing import Generator
import spaces

# Pick the compute device once at import time: CUDA when torch reports it is
# available, otherwise CPU. Any failure during detection also falls back to CPU.
try:
    if torch.cuda.is_available():
        torch_device = torch.device("cuda")
    else:
        torch_device = torch.device("cpu")
    print(f"Using device: {torch_device}")
except Exception as e:
    print(f"Error detecting GPU. Using CPU. Error: {e}")
    torch_device = torch.device("cpu")

# Load the pretrained AGC model and keep a single module-level instance that
# the encode/decode handlers below share.
def load_agc_model():
    """Return the ``Audiogen/agc-continuous`` AGC model moved to ``torch_device``."""
    model = AGC.from_pretrained("Audiogen/agc-continuous")
    return model.to(torch_device)


agc = load_agc_model()

@spaces.GPU(duration=180)
def encode_audio(audio_file_path, target_sample_rate: int = 48000):
    """Encode an audio file into an LZ4-compressed ``.owie`` latent file.

    The waveform is resampled to ``target_sample_rate``, mono input is
    duplicated to two channels, and the AGC latent is written as raw float32
    bytes (no shape header) compressed with ``lz4.frame``.

    Args:
        audio_file_path: Path of the uploaded audio file; ``None`` when the
            user clicked "Encode" without uploading anything.
        target_sample_rate: Rate the waveform is resampled to before encoding.
            48000 is assumed to be the AGC model's native rate (per the
            Audiogen AGC README) -- confirm before changing.

    Returns:
        Path of the temporary ``.owie`` file.

    Raises:
        gr.Error: If no file was given or loading/encoding fails. Gradio shows
            the message in the UI. Returning the message as a string would
            make the ``gr.File`` output try to open it as a file path.
    """
    if audio_file_path is None:
        raise gr.Error("Please upload an audio file to encode.")
    try:
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # Decoding writes audio at a fixed rate, so bring the input to that
        # same rate; otherwise the round trip changes pitch and speed.
        if sample_rate != target_sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, target_sample_rate
            )

        # Convert to stereo if necessary (mono -> duplicated channel).
        if waveform.size(0) == 1:
            waveform = waveform.repeat(2, 1)

        audio = waveform.unsqueeze(0).to(torch_device)  # add batch dimension
        with torch.no_grad():
            z = agc.encode(audio)

        # The decoders read these bytes back as float32, so pin the dtype.
        z_numpy = z.detach().cpu().numpy().astype(np.float32, copy=False)
        compressed_data = lz4.frame.compress(z_numpy.tobytes())

        # Write through the descriptor mkstemp already opened.
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        with os.fdopen(temp_fd, "wb") as temp_file:
            temp_file.write(compressed_data)
        return temp_file_path

    except Exception as e:
        raise gr.Error(f"Encoding error: {e}") from e

@spaces.GPU(duration=180)
def decode_audio(encoded_file_path, sample_rate: int = 48000):
    """Decode a ``.owie`` latent file back into a temporary WAV file.

    The file is expected to hold LZ4-compressed raw float32 bytes, as written
    by ``encode_audio``. Those bytes are reshaped to ``(1, 32, frames)``;
    the 32 latent channels are assumed for ``agc-continuous`` -- confirm.

    Args:
        encoded_file_path: Path of the uploaded ``.owie`` file; ``None`` when
            nothing was uploaded.
        sample_rate: Rate written into the WAV header. It must match the
            model's output rate; 48000 is assumed (Audiogen AGC README).
            The original code referenced an undefined ``sample_rate``, so
            every decode failed.

    Returns:
        Path of the temporary WAV file.

    Raises:
        gr.Error: If no file was given or loading/decoding fails. Gradio
            surfaces the message instead of treating it as an audio path.
    """
    if encoded_file_path is None:
        raise gr.Error("Please upload an encoded .owie file to decode.")
    try:
        with open(encoded_file_path, "rb") as encoded_file:
            compressed_data = encoded_file.read()
        z_numpy_bytes = lz4.frame.decompress(compressed_data)
        z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(1, 32, -1)
        # frombuffer returns a read-only view; copy so torch gets writable memory.
        z = torch.from_numpy(z_numpy.copy()).to(torch_device)

        with torch.no_grad():
            reconstructed_audio = agc.decode(z)

        # mkstemp creates the file atomically (unlike the deprecated, racy mktemp).
        temp_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_fd)  # torchaudio.save opens the path itself
        torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), sample_rate)
        return temp_wav_path

    except Exception as e:
        raise gr.Error(f"Decoding error: {e}") from e

@spaces.GPU(duration=180)
def stream_decode_audio(
    encoded_file_path, sample_rate: int = 48000, chunk_size: int = 16000
) -> Generator[tuple[int, np.ndarray], None, None]:
    """Decode a ``.owie`` latent file chunk by chunk for a streaming ``gr.Audio``.

    Each latent chunk is decoded independently, so chunk boundaries may be
    audible.

    Args:
        encoded_file_path: Path of the uploaded ``.owie`` file; ``None`` when
            nothing was uploaded.
        sample_rate: Rate reported to Gradio for every chunk. 48000 is assumed
            to be the AGC output rate (Audiogen AGC README) -- confirm.
        chunk_size: Number of *latent frames* (not audio samples) decoded per
            step.

    Yields:
        ``(sample_rate, samples)`` tuples with ``samples`` shaped
        ``(num_samples, channels)``. This is the layout streaming ``gr.Audio``
        outputs accept; a bare ``(channels, samples)`` array is not.

    Raises:
        gr.Error: If no file was given or loading/decoding fails.
    """
    if encoded_file_path is None:
        raise gr.Error("Please upload an encoded .owie file to stream.")
    try:
        with open(encoded_file_path, "rb") as encoded_file:
            compressed_data = encoded_file.read()
        z_numpy_bytes = lz4.frame.decompress(compressed_data)
        # Assumes 32 latent channels, matching decode_audio -- confirm.
        z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(1, 32, -1)
        # frombuffer returns a read-only view; copy so torch gets writable memory.
        z = torch.from_numpy(z_numpy.copy()).to(torch_device)

        with torch.no_grad():
            for start in range(0, z.shape[2], chunk_size):
                audio_chunk = agc.decode(z[:, :, start:start + chunk_size])
                # (1, channels, samples) -> (samples, channels) for Gradio.
                samples = np.ascontiguousarray(audio_chunk.squeeze(0).cpu().numpy().T)
                yield sample_rate, samples

    except Exception as e:
        # Previously this yielded "silence" sized by chunk_size. That name could
        # be unbound here, and the fake silence hid the failure from the user.
        print(f"Streaming decoding error: {e}")
        raise gr.Error(f"Streaming decoding error: {e}") from e

# Gradio UI: one tab per operation (encode, whole-file decode, streaming decode).
with gr.Blocks() as demo:
    gr.Markdown(value="## Audio Compression with AGC (GPU/CPU)")

    # Audio file in -> compressed .owie latent file out.
    with gr.Tab("Encode"):
        input_audio = gr.Audio(type="filepath", label="Input Audio")
        encode_button = gr.Button(value="Encode")
        encoded_output = gr.File(type="filepath", label="Encoded File (.owie)")

        encode_button.click(fn=encode_audio, inputs=input_audio, outputs=encoded_output)

    # .owie latent file in -> reconstructed audio file out.
    with gr.Tab("Decode"):
        input_encoded = gr.File(type="filepath", label="Encoded File (.owie)")
        decode_button = gr.Button(value="Decode")
        decoded_output = gr.Audio(type="filepath", label="Decoded Audio")

        decode_button.click(fn=decode_audio, inputs=input_encoded, outputs=decoded_output)

    # .owie latent file in -> audio streamed chunk by chunk from a generator.
    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(type="filepath", label="Encoded File (.owie)")
        stream_button = gr.Button(value="Start Streaming")
        audio_output = gr.Audio(streaming=True, label="Streaming Audio Output")

        stream_button.click(
            fn=stream_decode_audio,
            inputs=input_encoded_stream,
            outputs=audio_output,
        )

# The queue is enabled before launching (needed for generator/streaming handlers).
demo.queue()
demo.launch()