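"""Gradio web demo for MMOCR: text detection, recognition, and end-to-end OCR.

The script installs pinned dependencies at startup, downloads a few sample
images, and wraps ``MMOCRInferencer`` in a simple Gradio interface.
"""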
import os

# Install pinned runtime dependencies before the heavy imports below.
os.system("pip install gradio==3.42.0")
os.system("pip install 'mmengine>=0.6.0'")
os.system("pip install 'mmcv>=2.0.0rc4,<2.1.0'")
os.system("pip install 'mmdet>=3.0.0rc5,<3.2.0'")
os.system("pip install mmocr")
import json
import warnings
from argparse import ArgumentParser

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from mmocr.apis.inferencers import MMOCRInferencer

warnings.filterwarnings("ignore")


def save_image(img, img_path):
    # Convert PIL image to OpenCV image
    img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    # Save OpenCV image
    cv2.imwrite(img_path, img)


# Model names registered in MMOCR's metafiles, resolvable by MMOCRInferencer.
textdet_model_list = [
    'DBNet', 'DRRG', 'FCENet', 'PANet', 'PSENet', 'TextSnake', 'MaskRCNN'
]
textrec_model_list = [
    'ABINet', 'ASTER', 'CRNN', 'MASTER', 'NRTR', 'RobustScanner', 'SARNet',
    'SATRN', 'SVTR'
]
textkie_model_list = ['SDMGR']  # KIE is not currently exposed in the UI.

def ocr_inference(inputs, out_dir, det, det_weights, rec, rec_weights, device):
    # Note: parse_args() also yields `kie`/`kie_weights`, but KIE is not
    # wired into the Gradio interface, so it is not part of this signature.
    init_args, call_args = parse_args()
    inputs = np.array(inputs)
    img_path = "demo_text_ocr.jpg"
    save_image(inputs, img_path)
    if det is not None and rec is not None:
        # Metafile model names: let MMOCR resolve the pretrained weights.
        init_args['det'] = det
        init_args['det_weights'] = None
        init_args['rec'] = rec
        init_args['rec_weights'] = None
    elif det_weights is not None and rec_weights is not None:
        # Custom checkpoints: pass weight paths instead of model names.
        init_args['det'] = None
        init_args['det_weights'] = det_weights
        init_args['rec'] = None
        init_args['rec_weights'] = rec_weights
    call_args['inputs'] = img_path
    call_args['out_dir'] = out_dir
    call_args['batch_size'] = 1
    call_args['show'] = False
    call_args['save_pred'] = True
    call_args['save_vis'] = True
    init_args['device'] = device
    print("init_args", init_args)
    print("call_args", call_args)
    ocr = MMOCRInferencer(**init_args)
    ocr(**call_args)
    # MMOCRInferencer writes visualizations to <out_dir>/vis and predictions
    # to <out_dir>/preds, so derive both from out_dir rather than hard-coding.
    save_vis_dir = os.path.join(out_dir, 'vis')
    save_pred_dir = os.path.join(out_dir, 'preds')
    img_out = Image.open(os.path.join(save_vis_dir, img_path))
    with open(os.path.join(save_pred_dir, img_path.replace('.jpg', '.json'))) as f:
        json_out = json.load(f)
    return img_out, json_out
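
# For reference, a minimal sketch of the same inference path without the
# argparse indirection (model names as used above; paths are illustrative):
#
#   ocr = MMOCRInferencer(det='DBNet', rec='CRNN', device='cpu')
#   ocr('demo_text_ocr.jpg', out_dir='./results/', save_pred=True, save_vis=True)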


def download_test_image():
    # Images
    torch.hub.download_url_to_file(
        'https://user-images.githubusercontent.com/59380685/266821429-9a897c0a-5b02-4260-a65b-3514b758f6b6.jpg',
        'demo_densetext_det.jpg')
    torch.hub.download_url_to_file(
        'https://user-images.githubusercontent.com/59380685/266821432-17bb0646-a3e9-451e-9b4d-6e41ce4c3f0c.jpg',
        'demo_text_recog.jpg')
    torch.hub.download_url_to_file(
        'https://user-images.githubusercontent.com/59380685/266821434-fe0d4d18-f3e2-4acf-baf5-0d2e318f0b09.jpg',
        'demo_text_ocr.jpg')
    torch.hub.download_url_to_file(
        'https://user-images.githubusercontent.com/59380685/266821435-5d7af2b4-cb84-4355-91cb-37d90e91aa30.jpg',
        'demo_text_det.jpg')
    torch.hub.download_url_to_file(
        'https://user-images.githubusercontent.com/59380685/266821436-4790c6c1-2da5-45c7-b837-04eeea0d7264.jpeg',
        'demo_kie.jpg')


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        '--inputs', type=str, help='Input image file or folder path.')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='./results/',
        help='Output directory of results.')
    parser.add_argument(
        '--det',
        type=str,
        default=None,
        help='Pretrained text detection algorithm. It\'s the path to the '
             'config file or the model name defined in metafile.')
    parser.add_argument(
        '--det-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected det model. '
             'If it is not specified and "det" is a model name of metafile, the '
             'weights will be loaded from metafile.')
    parser.add_argument(
        '--rec',
        type=str,
        default=None,
        help='Pretrained text recognition algorithm. It\'s the path to the '
             'config file or the model name defined in metafile.')
    parser.add_argument(
        '--rec-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected recog model. '
             'If it is not specified and "rec" is a model name of metafile, the '
             'weights will be loaded from metafile.')
    parser.add_argument(
        '--kie',
        type=str,
        default=None,
        help='Pretrained key information extraction algorithm. It\'s the '
             'path to the config file or the model name defined in metafile.')
    parser.add_argument(
        '--kie-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected kie model. '
             'If it is not specified and "kie" is a model name of metafile, the '
             'weights will be loaded from metafile.')
    parser.add_argument(
        '--device',
        type=str,
        default=None,
        help='Device used for inference. '
             'If not specified, the available device will be automatically used.')
    parser.add_argument(
        '--batch-size', type=int, default=1, help='Inference batch size.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the image in a popup window.')
    parser.add_argument(
        '--print-result',
        action='store_true',
        help='Whether to print the results.')
    parser.add_argument(
        '--save-pred',
        action='store_true',
        help='Save the inference results to out_dir.')
    parser.add_argument(
        '--save-vis',
        action='store_true',
        help='Save the visualization results to out_dir.')

    call_args = vars(parser.parse_args())

    init_kws = [
        'det', 'det_weights', 'rec', 'rec_weights', 'kie', 'kie_weights', 'device'
    ]
    init_args = {}
    for init_kw in init_kws:
        init_args[init_kw] = call_args.pop(init_kw)

    return init_args, call_args
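
# parse_args() reads sys.argv, so the defaults above can be overridden when
# launching the demo, e.g. (script file name assumed):
#
#   python app.py --det DBNet --rec CRNN --device cpu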


if __name__ == '__main__':
    # Define Gradio input and output components.
    input_image = gr.Image(type="pil", label="Input Image")
    out_dir = gr.Textbox(value="results", label="Output Directory")
    det = gr.Dropdown(choices=textdet_model_list, value='DBNet',
                      label="Text Detection Model")
    det_weights = gr.Textbox(value=None, label="Detection Checkpoint Path (optional)")
    rec = gr.Dropdown(choices=textrec_model_list, value='CRNN',
                      label="Text Recognition Model")
    rec_weights = gr.Textbox(value=None, label="Recognition Checkpoint Path (optional)")
    device = gr.Radio(choices=["cpu", "cuda"], value="cpu",
                      label="Device used for inference")
    # Batch size is fixed to 1 inside ocr_inference, so it is not exposed here.
    output_image = gr.Image(type="pil", label="Output Image")
    output_json = gr.JSON(label="Prediction")
    download_test_image()
    # Each example must supply one value per input component, in order.
    examples = [
        ["demo_text_ocr.jpg", "results", "DBNet", None, "CRNN", None, "cpu"],
        ["demo_text_det.jpg", "results", "FCENet", None, "ASTER", None, "cpu"],
        ["demo_text_recog.jpg", "results", "FCENet", None, "MASTER", None, "cpu"],
    ]

    title = "MMOCR web demo"
    description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmocr/main/resources/mmocr-logo.png' width='450'/></div>" \
                  "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmocr'>MMOCR</a> is an open-source toolbox based on PyTorch and mmdetection, focusing on text detection, text recognition, and downstream tasks such as key information extraction. It is part of the OpenMMLab project. " \
                  "OpenMMLab Text Detection, Recognition and Understanding Toolbox.</p>"
    article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmocr'>MMOCR</a></p>" \
              "<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></p>"

    # Create Gradio interface
    iface = gr.Interface(
        fn=ocr_inference,
        inputs=[
            input_image, out_dir, det, det_weights, rec, rec_weights, device
        ],
        outputs=[output_image, output_json], examples=examples,
        title=title, description=description, article=article,
    )

    # Launch Gradio interface
    iface.launch()