import torch
import numpy as np
import gradio as gr
import yaml
import librosa
from tqdm.auto import tqdm
import spaces
import look2hear.models
from ml_collections import ConfigDict
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_audio(file_path):
audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
#audio = dBgain(audio, -6)
return torch.from_numpy(audio), samplerate
def get_config(config_path):
with open(config_path) as f:
#config = OmegaConf.load(config_path)
config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
return config
def _getWindowingArray(window_size, fade_size):
    # NOTE: no real cross-fade is applied; the fade-in is all ones and the fade-out
    # is all zeros, so the window only discards the unreliable tail of each chunk.
fadein = torch.linspace(1, 1, fade_size)
fadeout = torch.linspace(0, 0, fade_size)
window = torch.ones(window_size)
window[-fade_size:] *= fadeout
window[:fade_size] *= fadein
return window
description = '''
This is an unofficial Space for the audio restoration model Apollo: https://github.com/JusperLee/Apollo
'''
apollo_config = get_config('configs/apollo.yaml')
apollo_vocal2_config = get_config('configs/config_apollo_vocal.yaml')
apollo_uni_config = get_config('configs/config_apollo_uni.yaml')
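# Load the four pretrained Apollo variants and move them to the selected device.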
apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).to(device)
apollo_vocal = look2hear.models.BaseModel.from_pretrain('weights/apollo_vocal.bin', **apollo_config['model']).to(device)
apollo_vocal2 = look2hear.models.BaseModel.from_pretrain('weights/apollo_vocal2.bin', **apollo_vocal2_config['model']).to(device)
apollo_uni = look2hear.models.BaseModel.from_pretrain('weights/apollo_model_uni.ckpt', **apollo_uni_config['model']).to(device)
models = {
'apollo': apollo_model,
'apollo_vocal': apollo_vocal,
'apollo_vocal2': apollo_vocal2,
'apollo_uni': apollo_uni
}
choices = [
('MP3 restore', 'apollo'),
('Apollo vocal', 'apollo_vocal'),
('Apollo vocal2', 'apollo_vocal2'),
('Apollo universal', 'apollo_uni')
]
@spaces.GPU
def enhance(choice, audio):
print(choice)
model = models[choice]
test_data, samplerate = load_audio(audio)
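    # Chunked inference: split the audio into overlapping 10-second windows, run the
    # model on each chunk, and recombine the outputs with overlap-add so that chunk
    # boundaries do not leave audible seams.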
    C = 10 * samplerate  # chunk length: 10 seconds, in samples
    N = 2  # overlap factor: the hop size is 1/N of the chunk length
    step = C // N
    fade_size = 3 * 44100  # 3 seconds at 44.1 kHz, zeroed out at the end of each chunk
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
    border = C - step  # reflect-padding width added to each side of the input
# handle mono inputs correctly
if len(test_data.shape) == 1:
test_data = test_data.unsqueeze(0)
# Pad the input if necessary
if test_data.shape[1] > 2 * border and (border > 0):
test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')
windowingArray = _getWindowingArray(C, fade_size)
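    # Accumulation buffers for overlap-add: 'result' sums the windowed model outputs,
    # 'counter' sums the window weights applied at every sample position.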
result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
i = 0
progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)
while i < test_data.shape[1]:
part = test_data[:, i:i + C]
length = part.shape[-1]
        if length < C:
            if length > C // 2 + 1:
                # Tail longer than half a chunk: reflect-pad up to the full chunk length.
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                # Very short tail: reflect padding would exceed the input length, so zero-pad instead.
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)
chunk = part.unsqueeze(0).to(device)
with torch.no_grad():
out = model(chunk).squeeze(0).squeeze(0).cpu()
window = windowingArray
if i == 0: # First audio chunk, no fadein
window[:fade_size] = 1
elif i + C >= test_data.shape[1]: # Last audio chunk, no fadeout
window[-fade_size:] = 1
result[..., i:i+length] += out[..., :length] * window[..., :length]
counter[..., i:i+length] += window[..., :length]
i += step
progress_bar.update(step)
progress_bar.close()
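    # Overlap-add normalization: divide the accumulated outputs by the accumulated
    # window weights so overlapping regions keep unit gain.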
final_output = result / counter
final_output = final_output.squeeze(0).numpy()
    np.nan_to_num(final_output, copy=False, nan=0.0)  # zero out NaNs from positions where counter is 0
# Remove padding if added earlier
if test_data.shape[1] > 2 * border and (border > 0):
final_output = final_output[..., border:-border]
return samplerate, final_output.T
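# Example (hypothetical local call that bypasses the Gradio UI; 'input.mp3' is an assumed path):
#   sr, restored = enhance('apollo', 'input.mp3')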
if __name__ == "__main__":
i = gr.Interface(
        fn=enhance,
description=description,
inputs=[
gr.Dropdown(label="Model", choices=choices, value=choices[0][1]),
gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=3000, waveform_options={'waveform_progress_color': '#3C82F6'}),
],
outputs=[
gr.Audio(
label="Output Audio",
autoplay=False,
streaming=False,
type="numpy",
),
],
        allow_flagging='never',
cache_examples=False,
title='Apollo audio restoration',
)
i.queue(max_size=20, default_concurrency_limit=4)
i.launch(share=False, server_name="0.0.0.0")
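# To run locally (assumes the configs/ and weights/ files from the Space are present):
#   python app.py
# then open http://localhost:7860 (Gradio's default port) in a browser.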