File size: 4,991 Bytes
f18f98b eb0f782 a2cc897 bd40662 eb0f782 bd40662 eb0f782 c27eb74 eb0f782 0b50165 eb0f782 a7dbbfe eb0f782 c9a89ac eb0f782 0b50165 eb0f782 c27eb74 eb0f782 c27eb74 eb0f782 c27eb74 eb0f782 0b50165 eb0f782 c27eb74 eb0f782 a7dbbfe eb0f782 3226a34 eb0f782 0b50165 eb0f782 0b50165 3226a34 c27eb74 eb0f782 0b50165 c27eb74 eb0f782 d07be48 0b50165 44bab11 f18f98b eb0f782 44bab11 eb0f782 44bab11 f18f98b eb0f782 44bab11 eb0f782 f18f98b c9a89ac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import gradio as gr
import torch
import torchaudio
from agc import AGC
import tempfile
import numpy as np
import lz4.frame
import os
from typing import Generator
import spaces
def _pick_device() -> torch.device:
    """Return the CUDA device when ``torch.cuda.is_available()``, else the CPU device.

    If probing CUDA (or reporting the choice) raises, the error is printed and
    the CPU device is returned instead.
    """
    try:
        chosen = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {chosen}")
    except Exception as err:
        print(f"Error detecting GPU. Using CPU. Error: {err}")
        chosen = torch.device("cpu")
    return chosen


# Device used for the model and for every tensor fed to it.
torch_device = _pick_device()
@spaces.GPU(duration=180)
def load_agc_model():
    """Fetch the pretrained continuous AGC checkpoint and move it to ``torch_device``.

    Returns:
        The object returned by ``AGC.from_pretrained("Audiogen/agc-continuous")``
        after ``.to(torch_device)``.
    """
    checkpoint = "Audiogen/agc-continuous"
    model = AGC.from_pretrained(checkpoint)
    return model.to(torch_device)


# Module-level model instance used by the encode/decode handlers.
agc = load_agc_model()
@spaces.GPU(duration=180)
def encode_audio(audio_file_path):
    """Encode an audio file into an LZ4-compressed ``.owie`` latent file.

    The audio is resampled to 32 kHz and forced to exactly 32 channels
    (channels are repeated when there are fewer, truncated when there are
    more) before being passed to ``agc.encode``. The latent is written as raw
    float32 bytes inside an LZ4 frame. No shape or dtype header is stored, so
    the reader must assume float32 and 32 latent channels.

    Args:
        audio_file_path: Path of the uploaded audio file (``gr.Audio`` with
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        Path of the newly created ``.owie`` temporary file.

    Raises:
        gr.Error: If no file was given or any load/encode/write step fails.
            Gradio shows the message in the UI. Returning it as a string would
            instead make the ``gr.File`` output treat the message as a path.
    """
    if audio_file_path is None:
        raise gr.Error("Encoding error: no audio file provided")
    try:
        waveform, sample_rate = torchaudio.load(audio_file_path)

        # NOTE(review): 32 kHz / 32 channels are this app's choices; confirm they
        # match the input format the AGC checkpoint actually expects.
        if sample_rate != 32000:
            waveform = torchaudio.transforms.Resample(sample_rate, 32000)(waveform)
        num_channels = waveform.size(0)
        if num_channels < 32:
            waveform = waveform.repeat(32, 1)[:32, :]
        elif num_channels > 32:
            waveform = waveform[:32, :]

        audio = waveform.unsqueeze(0).to(torch_device)  # add a batch dimension of 1
        with torch.no_grad():
            z = agc.encode(audio)

        # Pin float32: the decoders parse this payload with dtype=np.float32, so
        # any other dtype (e.g. a half-precision model) would decode as garbage.
        z_numpy = z.detach().cpu().to(torch.float32).numpy()
        compressed_data = lz4.frame.compress(z_numpy.tobytes())

        # mkstemp creates the file securely; write through its descriptor directly.
        temp_fd, temp_file_path = tempfile.mkstemp(suffix=".owie")
        try:
            with os.fdopen(temp_fd, "wb") as temp_file:
                temp_file.write(compressed_data)
        except Exception:
            os.remove(temp_file_path)  # don't leave a truncated file behind
            raise
        return temp_file_path
    except Exception as e:
        raise gr.Error(f"Encoding error: {e}") from e
@spaces.GPU(duration=180)
def decode_audio(encoded_file_path):
    """Decode an ``.owie`` latent file back into a WAV file.

    The file is expected to hold LZ4-frame-compressed raw float32 bytes. They
    are reshaped to ``(1, 32, -1)`` (batch 1, 32 latent channels) and passed to
    ``agc.decode``.

    Args:
        encoded_file_path: Path of the uploaded ``.owie`` file (``gr.File`` with
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        Path of a newly created temporary ``.wav`` file, written at 32000 Hz.

    Raises:
        gr.Error: If no file was given or any read/decompress/decode/save step
            fails. Gradio shows the message in the UI. Returning it as a string
            would instead make the ``gr.Audio`` output treat the message as a path.
    """
    if encoded_file_path is None:
        raise gr.Error("Decoding error: no encoded file provided")
    try:
        with open(encoded_file_path, "rb") as encoded_file:
            compressed_data = encoded_file.read()
        z_numpy_bytes = lz4.frame.decompress(compressed_data)
        # np.frombuffer returns a read-only view; copy it so torch.from_numpy
        # receives a writable array (PyTorch warns on non-writable inputs).
        z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(1, 32, -1).copy()
        z = torch.from_numpy(z_numpy).to(torch_device)

        with torch.no_grad():
            reconstructed_audio = agc.decode(z)

        # mkstemp (unlike the deprecated mktemp) creates the file atomically, so
        # the name cannot be claimed by someone else between naming and writing.
        temp_fd, temp_wav_path = tempfile.mkstemp(suffix=".wav")
        os.close(temp_fd)  # torchaudio.save opens the path itself
        # NOTE(review): 32000 Hz assumes the model decodes at the same rate the
        # encoder resampled to; confirm against the AGC checkpoint.
        torchaudio.save(temp_wav_path, reconstructed_audio.squeeze(0).cpu(), 32000)
        return temp_wav_path
    except Exception as e:
        raise gr.Error(f"Decoding error: {e}") from e
@spaces.GPU(duration=180)
def stream_decode_audio(encoded_file_path) -> Generator[tuple, None, None]:
    """Decode an ``.owie`` latent file chunk by chunk, yielding audio for streaming.

    The latent (LZ4-compressed raw float32, reshaped to ``(1, 32, -1)``) is
    sliced along its last axis into ``chunk_size``-frame pieces. Each piece is
    decoded independently with ``agc.decode``.

    Args:
        encoded_file_path: Path of the uploaded ``.owie`` file, or ``None``.

    Yields:
        ``(sample_rate, audio_data)`` tuples. ``audio_data`` is a float NumPy
        array of shape ``(samples, channels)``, the layout ``gr.Audio`` expects
        for NumPy audio. On any error the message is printed and a single
        one-second chunk of 32-channel silence is yielded instead (best effort).
    """
    # Bound before the try so the error handler below can always use them.
    # Previously they were assigned after file loading, so a load failure
    # raised UnboundLocalError from inside the handler.
    sample_rate = 32000  # output rate, same as decode_audio
    # Number of *latent frames* decoded per step (z is sliced on its last axis).
    # NOTE(review): this is not "1 second of audio"; tune it against the
    # model's latent frame rate for real incremental streaming.
    chunk_size = 32000
    silence_channels = 32
    try:
        with open(encoded_file_path, "rb") as encoded_file:
            compressed_data = encoded_file.read()
        z_numpy_bytes = lz4.frame.decompress(compressed_data)
        # Copy the read-only frombuffer view so torch.from_numpy gets a writable array.
        z_numpy = np.frombuffer(z_numpy_bytes, dtype=np.float32).reshape(1, 32, -1).copy()
        z = torch.from_numpy(z_numpy).to(torch_device)

        with torch.no_grad():
            for start in range(0, z.shape[2], chunk_size):
                audio_chunk = agc.decode(z[:, :, start:start + chunk_size])
                # Decoded audio is (1, channels, samples), the same layout
                # decode_audio passes to torchaudio.save. Gradio wants
                # (samples, channels), so drop the batch dim and transpose.
                audio_data = np.ascontiguousarray(audio_chunk.squeeze(0).cpu().numpy().T)
                yield (sample_rate, audio_data)
    except Exception as e:
        print(f"Streaming decoding error: {e}")
        # Best effort: one second of silence, shaped (samples, channels).
        yield (sample_rate, np.zeros((sample_rate, silence_channels), dtype=np.float32))
# ---------------------------------------------------------------------------
# Gradio interface: one tab each for encoding, file decoding and streaming.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Audio Compression with AGC (GPU/CPU) - 32 channels, 32kHz")

    # Audio file -> compressed .owie latent file.
    with gr.Tab("Encode"):
        input_audio = gr.Audio(label="Input Audio", type="filepath")
        encode_button = gr.Button("Encode")
        encoded_output = gr.File(label="Encoded File (.owie)", type="filepath")
        encode_button.click(fn=encode_audio, inputs=input_audio, outputs=encoded_output)

    # .owie latent file -> WAV file.
    with gr.Tab("Decode"):
        input_encoded = gr.File(label="Encoded File (.owie)", type="filepath")
        decode_button = gr.Button("Decode")
        decoded_output = gr.Audio(label="Decoded Audio", type="filepath")
        decode_button.click(fn=decode_audio, inputs=input_encoded, outputs=decoded_output)

    # .owie latent file -> audio streamed chunk by chunk from a generator.
    with gr.Tab("Streaming"):
        input_encoded_stream = gr.File(label="Encoded File (.owie)", type="filepath")
        stream_button = gr.Button("Start Streaming")
        audio_output = gr.Audio(label="Streaming Audio Output", streaming=True)
        stream_button.click(
            fn=stream_decode_audio,
            inputs=input_encoded_stream,
            outputs=audio_output,
        )

# Enable the request queue, then start the server.
demo.queue()
demo.launch()