File size: 3,634 Bytes
f91e7b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr


import os
import sys
base_path = os.path.expanduser('~')

sys.path.append(os.path.join(base_path, 'Er0mangaSeg/'))
sys.path.append(os.path.join(base_path, 'Er0mangaSeg/demo'))
from image_demo_tta import init_seg_model, inference_tta

sys.path.append(os.path.join(base_path, 'Er0mangaInpaint/'))
sys.path.append(os.path.join(base_path, 'Er0mangaInpaint/bin'))
from uncen import init_inpaint_model, inpaint


import time
import numpy as np
import cv2
import shutil
import torch


# Choose the compute device once, at import time: first CUDA GPU if present, else CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('GPU found!' if device != 'cpu' else 'GPU not found! Using CPU')

# Segmentation model: ConvNeXt config plus pretrained checkpoint from the Er0mangaSeg checkout.
config, checkpoint = (
    os.path.join(base_path, 'Er0mangaSeg/configs/convnext/convnext_h.py'),
    os.path.join(base_path, 'Er0mangaSeg/pretrained/convnext_1024_iter_400.pth'),
)
model_seg = init_seg_model(config, checkpoint, device=device)
print('Segmentation initialized')

# Inpainting model loaded from the Er0mangaInpaint checkout.
# NOTE(review): `device` is not passed here, so device placement is left to
# init_inpaint_model — confirm it behaves on CPU-only hosts.
inp_model_path = os.path.join(base_path, 'Er0mangaInpaint/pretrained/00-30-09')
model_inp = init_inpaint_model(inp_model_path)
print('Inpainting initialized')


def proc(input_img):
    """Segment and inpaint a single image and return the inpainted result.

    Runs the segmentation model (``inference_tta``) on ``input_img`` to obtain a
    mask. It stacks three copies of that mask along the channel axis and passes
    the image and mask to the inpainting model (``inpaint``).

    Args:
        input_img: Value of the ``gr.Image()`` input component. Gradio passes
            ``None`` when the user submits without uploading an image.
            (Presumably an RGB numpy array, Gradio's default — confirm.)

    Returns:
        The inpainted image produced by ``inpaint``.

    Raises:
        gr.Error: If no image was supplied or any processing step fails. The
            message is shown to the user in the Gradio UI.
    """
    # Fail with a readable message instead of an opaque error from deep inside
    # the segmentation model.
    if input_img is None:
        raise gr.Error('No image provided')

    try:
        start = time.time()

        # Only the post-processed mask is used; the raw mask is discarded.
        out_mask, _raw_mask = inference_tta(model_seg, input_img)
        # Three copies along the channel axis (assumes a 2-D mask — confirm).
        out_mask = np.dstack([out_mask, out_mask, out_mask])

        output_img, _out_dbg = inpaint(model_inp, input_img, out_mask)

        print(f"proc_time: {time.time() - start:.2f}")
        return output_img

    except Exception as err:
        # gr.Error takes a message string. Chain the original exception so the
        # full traceback still reaches the server log.
        raise gr.Error(f'{type(err).__name__}: {err}') from err


def proc_batch(batch):
    """Segment and inpaint every image of a gallery upload and bundle the results.

    Each image is read from disk, segmented with ``inference_tta`` and inpainted
    with ``inpaint``. The inpainted images and raw masks are written to two
    sub-directories of a fresh per-call working directory. Each sub-directory is
    then packed into a zip archive (``output.zip`` / ``output_mask.zip``).

    Args:
        batch: Value of the ``gr.Gallery()`` input. Each item is indexed with
            ``[0]`` to get a file path — presumably ``(path, caption)`` pairs as
            produced by the installed Gradio version; confirm.

    Returns:
        A 3-tuple ``(image_paths, images_zip_path, masks_zip_path)``.

    Raises:
        gr.Error: If the batch is empty, an image cannot be read or written, or
            any processing step fails.
    """
    import tempfile  # local import: only needed for the per-call work directory

    if not batch:
        raise gr.Error('No images provided')

    res = []
    work_dir = None
    try:
        start = time.time()

        # Create the work directory next to the first uploaded file, a location
        # Gradio already serves files from. mkdtemp guarantees a fresh, uniquely
        # named directory, so concurrent batches can neither collide on the
        # sub-directories nor overwrite each other's archives.
        work_dir = tempfile.mkdtemp(prefix='__batch_',
                                    dir=os.path.dirname(batch[0][0]))
        out_p_d = os.path.join(work_dir, 'img')
        out_p_m = os.path.join(work_dir, 'mask')
        os.mkdir(out_p_d)
        os.mkdir(out_p_m)

        used_names = set()
        for idx, item in enumerate(batch):
            input_path = item[0]
            inp_name = os.path.basename(input_path)
            # Uploads from different folders may share a basename. Prefix the
            # index until the name is unique so no output is silently overwritten.
            while inp_name in used_names:
                inp_name = f'{idx}_{inp_name}'
            used_names.add(inp_name)

            # cv2.imread returns None (rather than raising) for unreadable files.
            bgr_img = cv2.imread(input_path)
            if bgr_img is None:
                raise ValueError(f'could not read image {os.path.basename(input_path)!r}')
            input_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)

            out_mask, raw_mask = inference_tta(model_seg, input_img)
            out_mask = np.dstack([out_mask, out_mask, out_mask])
            raw_mask = np.dstack([raw_mask, raw_mask, raw_mask])

            output_img, _out_dbg = inpaint(model_inp, input_img, out_mask)

            out_path_img = os.path.join(out_p_d, inp_name)
            out_path_mask = os.path.join(out_p_m, inp_name + '.png')
            # output_img is presumably RGB like input_img; OpenCV writes BGR.
            if not cv2.imwrite(out_path_img, cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)):
                raise IOError(f'could not write {inp_name!r} (unsupported extension?)')
            if not cv2.imwrite(out_path_mask, raw_mask):
                raise IOError(f'could not write mask for {inp_name!r}')
            res.append(out_path_img)

        # make_archive returns the path of the archive it created.
        ar_path = shutil.make_archive(os.path.join(work_dir, 'output'), 'zip', out_p_d)
        ar_path_m = shutil.make_archive(os.path.join(work_dir, 'output_mask'), 'zip', out_p_m)

        print(f"batch proc_time: {time.time() - start:.2f}")

        return res, ar_path, ar_path_m

    except Exception as err:
        # Don't leave a half-written work directory behind.
        if work_dir is not None:
            shutil.rmtree(work_dir, ignore_errors=True)
        raise gr.Error(f'{type(err).__name__}: {err}') from err



# Tab 1: one uploaded image in, one inpainted PNG out.
# delete_cache is presumably (frequency, age) in seconds — see the Gradio docs.
demo1 = gr.Interface(
    fn=proc,
    inputs=gr.Image(),
    outputs=gr.Image(format='png'),
    delete_cache=(7200, 7200),
    allow_flagging='never',
)

# Tab 2: a gallery of images in; a gallery of results plus two downloadable
# files (the zip of inpainted images and the zip of masks) out.
# NOTE(review): value='str' is not a valid gallery of images — it looks like a
# typo for a `type=`-style option; confirm what was intended.
demo2 = gr.Interface(
    fn=proc_batch,
    inputs=gr.Gallery(),
    outputs=[gr.Gallery(value='str', format='png'), gr.File(), gr.File()],
    delete_cache=(7200, 7200),
    allow_flagging='never',
)

tab_labels = ["Single image processing", "Batch processing (experimental)"]
demo = gr.TabbedInterface([demo1, demo2], tab_labels)

if __name__ == "__main__":
    # Listens on all network interfaces, port 7860.
    demo.launch(server_name='0.0.0.0', server_port=7860)