Apollo / app.py
Serhiy Stetskovych
Add universal model and updated texts.
import torch
import numpy as np
import gradio as gr
import yaml
import librosa
from tqdm.auto import tqdm
import spaces
import look2hear.models
from ml_collections import ConfigDict
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def load_audio(file_path, sr=44100):
    """Decode an audio file with librosa and return it as a torch tensor.

    The file is resampled to ``sr`` and its channels are kept (``mono=False``).
    Multi-channel input comes back as ``(channels, samples)``; a mono file
    comes back 1-D.

    Args:
        file_path: path of the audio file to decode.
        sr: target sample rate in Hz (default 44100, the previous fixed value).

    Returns:
        ``(audio, samplerate)``: ``audio`` is a ``torch.Tensor`` built from
        librosa's NumPy output, and ``samplerate`` is the rate librosa reports.
    """
    audio, samplerate = librosa.load(file_path, mono=False, sr=sr)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    return torch.from_numpy(audio), samplerate
def get_config(config_path):
    """Parse a YAML file into an ``ml_collections.ConfigDict``.

    Args:
        config_path: path of the YAML file to read.

    Returns:
        A ``ConfigDict`` wrapping the parsed YAML mapping.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(config_path, encoding='utf-8') as f:
        # NOTE(review): yaml.FullLoader is not safe for untrusted YAML. It is
        # acceptable for the bundled config files; switch to yaml.safe_load
        # if config paths could ever come from users.
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    return config
def _getWindowingArray(window_size, fade_size):
# no fades here in the end, only removing the failed ending of the chunk
fadein = torch.linspace(1, 1, fade_size)
fadeout = torch.linspace(0, 0, fade_size)
window = torch.ones(window_size)
window[-fade_size:] *= fadeout
window[:fade_size] *= fadein
return window
# Intro text for the demo; presumably passed to gr.Interface(description=...),
# though that argument is not visible here.
# NOTE(review): the closing ''' of this string is not visible in this view — confirm.
description = f'''
This is unofficial space for audio restoration model Apollo: https://github.com/JusperLee/Apollo
# Parsed YAML configs; the `model` section of each one is passed to
# BaseModel.from_pretrain as keyword arguments.
apollo_config, apollo_vocal2_config, apollo_uni_config = [
    get_config(config_path)
    for config_path in (
        'configs/apollo.yaml',
        'configs/config_apollo_vocal.yaml',
        'configs/config_apollo_uni.yaml',
    )
]


def _load_apollo(weights_path, config):
    """Load ``weights_path`` via ``BaseModel.from_pretrain`` and move it to ``device``."""
    return look2hear.models.BaseModel.from_pretrain(weights_path, **config['model']).to(device)


# NOTE(review): apollo_vocal reuses the base apollo_config rather than
# apollo_vocal2_config; presumably the same architecture — confirm.
apollo_model = _load_apollo('weights/apollo.bin', apollo_config)
apollo_vocal = _load_apollo('weights/apollo_vocal.bin', apollo_config)
apollo_vocal2 = _load_apollo('weights/apollo_vocal2.bin', apollo_vocal2_config)
apollo_uni = _load_apollo('weights/apollo_model_uni.ckpt', apollo_uni_config)
# Dropdown value -> loaded model; enchance() looks up models[choice], so these
# keys must match the values listed in `choices`.
models = dict(
    apollo=apollo_model,
    apollo_vocal=apollo_vocal,
    apollo_vocal2=apollo_vocal2,
    apollo_uni=apollo_uni,
)
# (label shown in the Model dropdown, key into `models`) pairs.
choices = list(zip(
    ('MP3 restore', 'Apollo vocal', 'Apollo vocal2', 'Apollo universal'),
    ('apollo', 'apollo_vocal', 'apollo_vocal2', 'apollo_uni'),
))
# NOTE(review): `spaces` is imported at module level but no @spaces.GPU
# decorator is visible on this function — confirm whether one is intended.
def enchance(choice, audio):
    """Restore an audio file with the selected Apollo model.

    The signal is processed in 10 s chunks with a 5 s hop, so neighbouring
    chunks overlap. Chunks that do not reach the end of the signal get zero
    weight on their final 3 s. The overlapping outputs are then averaged by
    their accumulated weights.

    Args:
        choice: key into ``models`` selecting the checkpoint to run.
        audio: path of the input audio file.

    Returns:
        ``(samplerate, samples)`` where ``samples`` is a NumPy array shaped
        ``(num_samples, channels)``.
    """
    model = models[choice]
    test_data, samplerate = load_audio(audio)

    C = 10 * samplerate  # chunk_size seconds to samples
    N = 2  # each chunk advances by C // N samples
    step = C // N
    fade_size = 3 * samplerate  # 3 seconds (follows samplerate, like C)
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
    border = C - step

    # handle mono inputs correctly
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0)

    # Reflect-pad both ends by `border`; remember whether we did, so the
    # padding can be cropped off again before returning.
    padded = test_data.shape[1] > 2 * border and border > 0
    if padded:
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    total = test_data.shape[1]
    windowing_array = _getWindowingArray(C, fade_size)
    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros_like(result)

    with tqdm(total=total, desc="Processing audio chunks", leave=False) as progress_bar:
        i = 0
        while i < total:
            part = test_data[:, i:i + C]
            length = part.shape[-1]
            if length < C:
                # Reflect padding needs the pad to be shorter than the input,
                # so very short tails are zero-padded instead.
                if length > C // 2 + 1:
                    part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
                else:
                    part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

            chunk = part.unsqueeze(0).to(device)
            with torch.no_grad():
                out = model(chunk).squeeze(0).squeeze(0).cpu()

            # Copy so the per-chunk edits below never leak into the shared window.
            window = windowing_array.clone()
            if i == 0:  # First audio chunk, no fadein
                window[:fade_size] = 1
            elif i + C >= total:  # Last audio chunk, no fadeout
                window[-fade_size:] = 1

            result[..., i:i + length] += out[..., :length] * window[..., :length]
            counter[..., i:i + length] += window[..., :length]

            next_i = i + step
            progress_bar.update(min(next_i, total) - i)
            i = next_i

    final_output = (result / counter).squeeze(0).numpy()
    # Samples that received no weight are 0/0 = NaN; replace them with silence.
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Remove padding if added earlier
    if padded:
        final_output = final_output[..., border:-border]

    return samplerate, final_output.T
if __name__ == "__main__":
    # Build the Gradio UI: a model selector plus an audio-file input.
    # NOTE(review): this call looks incomplete in this view. There is no
    # visible callback (fn=) argument and no brackets grouping the two input
    # components. Nothing opens the component that owns label="Output Audio",
    # and the call has no closing parenthesis. Presumably `enchance` is the
    # callback and an output gr.Audio carries that label — confirm against
    # the full file.
    i = gr.Interface(
        gr.Dropdown(label="Model", choices=choices, value=choices[0][1]),
        gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=3000, waveform_options={'waveform_progress_color': '#3C82F6'}),
        label="Output Audio",
        allow_flagging ='never',
        title='Apollo audio restoration',
    # Queue at most 20 pending requests and run at most 4 events concurrently.
    i.queue(max_size=20, default_concurrency_limit=4)
    # NOTE(review): server_name="" may be truncated (e.g. "0.0.0.0" to listen
    # on all interfaces) — confirm the intended bind address.
    i.launch(share=False, server_name="")