import torch
import numpy as np
import gradio as gr
import yaml
import librosa
from tqdm.auto import tqdm
import spaces

import look2hear.models
from ml_collections import ConfigDict

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def load_audio(file_path):
    # Load at 44.1 kHz and keep the original channel layout (mono stays 1-D here).
    audio, samplerate = librosa.load(file_path, mono=False, sr=44100)
    print(f'INPUT audio.shape = {audio.shape} | samplerate = {samplerate}')
    #audio = dBgain(audio, -6)
    return torch.from_numpy(audio), samplerate


def get_config(config_path):
    # Parse a YAML config file into an ml_collections.ConfigDict.
    with open(config_path) as f:
        #config = OmegaConf.load(config_path)
        config = ConfigDict(yaml.load(f, Loader=yaml.FullLoader))
    return config


def _getWindowingArray(window_size, fade_size):
    # IMPORTANT NOTE:
    # no cross-fades here; the window is all ones except for the last `fade_size`
    # samples, which are zeroed to discard the unreliable tail of each chunk
    # (the overlap-add normalisation by `counter` compensates for this).
    fadein = torch.linspace(1, 1, fade_size)   # constant 1: no fade-in
    fadeout = torch.linspace(0, 0, fade_size)  # constant 0: drop the chunk tail
    window = torch.ones(window_size)
    window[-fade_size:] *= fadeout
    window[:fade_size] *= fadein
    return window
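# Example: _getWindowingArray(10, 3) -> tensor([1., 1., 1., 1., 1., 1., 1., 0., 0., 0.])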



description = '''
This is an unofficial Space for Apollo, an audio restoration model: https://github.com/JusperLee/Apollo
'''


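# Load the YAML configs and pretrained checkpoints for each Apollo variant once at
# startup, and move every model to the selected device.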
apollo_config = get_config('configs/apollo.yaml')
apollo_vocal2_config = get_config('configs/config_apollo_vocal.yaml')
apollo_uni_config = get_config('configs/config_apollo_uni.yaml')
apollo_model = look2hear.models.BaseModel.from_pretrain('weights/apollo.bin', **apollo_config['model']).to(device)
apollo_vocal = look2hear.models.BaseModel.from_pretrain('weights/apollo_vocal.bin', **apollo_config['model']).to(device)
apollo_vocal2 = look2hear.models.BaseModel.from_pretrain('weights/apollo_vocal2.bin', **apollo_vocal2_config['model']).to(device)
apollo_uni = look2hear.models.BaseModel.from_pretrain('weights/apollo_model_uni.ckpt', **apollo_uni_config['model']).to(device)
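# Minimal usage sketch (hypothetical tensor `x` of shape [channels, samples] at 44.1 kHz),
# mirroring how `enhance` below calls the models on each chunk:
#   with torch.no_grad():
#       restored = apollo_model(x.unsqueeze(0).to(device)).squeeze(0).cpu()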



models = {
    'apollo': apollo_model,
    'apollo_vocal': apollo_vocal,
    'apollo_vocal2': apollo_vocal2,
    'apollo_uni': apollo_uni
}

choices = [
    ('MP3 restore', 'apollo'),
    ('Apollo vocal', 'apollo_vocal'),
    ('Apollo vocal2', 'apollo_vocal2'),
    ('Apollo universal', 'apollo_uni')
]

@spaces.GPU
def enhance(choice, audio):
    print(choice)
    model = models[choice]
    test_data, samplerate = load_audio(audio)
    C = 10 * samplerate  # chunk size: 10 seconds, in samples
    N = 2                # chunks advance by C // N samples (50% overlap)
    step = C // N
    fade_size = 3 * 44100  # 3 seconds
    print(f"N = {N} | C = {C} | step = {step} | fade_size = {fade_size}")
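    # Overlap-add scheme: the track is processed in C-sample chunks advancing by `step`;
    # windowed outputs are accumulated in `result` and normalised by `counter` below.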
    
    border = C - step
    
    # handle mono inputs correctly
    if len(test_data.shape) == 1:
        test_data = test_data.unsqueeze(0) 

    # Pad the input if necessary
    if test_data.shape[1] > 2 * border and (border > 0):
        test_data = torch.nn.functional.pad(test_data, (border, border), mode='reflect')

    windowingArray = _getWindowingArray(C, fade_size)

    # Overlap-add buffers: `result` accumulates windowed model outputs,
    # `counter` accumulates the window weights used for normalisation.
    result = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)
    counter = torch.zeros((1,) + tuple(test_data.shape), dtype=torch.float32)

    i = 0
    progress_bar = tqdm(total=test_data.shape[1], desc="Processing audio chunks", leave=False)

    while i < test_data.shape[1]:
        part = test_data[:, i:i + C]
        length = part.shape[-1]
        if length < C:
            # Pad a short final chunk up to C samples: reflect-pad when enough signal
            # is available to mirror, otherwise zero-pad (reflect requires pad < input length).
            if length > C // 2 + 1:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length), mode='reflect')
            else:
                part = torch.nn.functional.pad(input=part, pad=(0, C - length, 0, 0), mode='constant', value=0)

        chunk = part.unsqueeze(0).to(device)
        with torch.no_grad():
            out = model(chunk).squeeze(0).squeeze(0).cpu()

        window = windowingArray.clone()  # per-chunk copy so the edits below don't mutate the shared window
        if i == 0:  # First audio chunk, no fadein
            window[:fade_size] = 1
        elif i + C >= test_data.shape[1]:  # Last audio chunk, no fadeout
            window[-fade_size:] = 1

        result[..., i:i+length] += out[..., :length] * window[..., :length]
        counter[..., i:i+length] += window[..., :length]

        i += step
        progress_bar.update(step)

    progress_bar.close()

    final_output = result / counter
    final_output = final_output.squeeze(0).numpy()
    # Samples never covered by a non-zero window divide to NaN; replace them with silence.
    np.nan_to_num(final_output, copy=False, nan=0.0)

    # Remove padding if added earlier
    if test_data.shape[1] > 2 * border and (border > 0):
        final_output = final_output[..., border:-border]

    # Gradio's numpy audio output expects (samples, channels), hence the transpose.
    return samplerate, final_output.T


if __name__ == "__main__":
    i = gr.Interface(
        fn=enhance,
        description=description,
        inputs=[
            gr.Dropdown(label="Model", choices=choices, value=choices[0][1]),
            gr.Audio(label="Input Audio:", interactive=True, type='filepath', max_length=3000, waveform_options={'waveform_progress_color': '#3C82F6'}),
        ],
        outputs=[
            gr.Audio(
                label="Output Audio",
                autoplay=False,
                streaming=False,
                type="numpy",
            ),
        ],
        allow_flagging='never',
        cache_examples=False,
        title='Apollo audio restoration',
    )
    i.queue(max_size=20, default_concurrency_limit=4)
    i.launch(share=False, server_name="0.0.0.0")