isLinXu commited on
Commit
7358262
1 Parent(s): e50891a

update app.py

Browse files
Files changed (2) hide show
  1. app.py +223 -0
  2. requirements.txt +20 -0
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ os.system("pip install gradio==3.42.0")
4
+ os.system("pip install 'mmengine>=0.6.0'")
5
+ os.system("pip install 'mmcv>=2.0.0rc4,<2.1.0'")
6
+ os.system("pip install 'mmdet>=3.0.0,<4.0.0'")
7
+ os.system("pip install mmocr")
8
+
9
+ import json
10
+ import os
11
+ from argparse import ArgumentParser
12
+
13
+ import PIL
14
+ import cv2
15
+ import gradio as gr
16
+ import numpy as np
17
+ import torch
18
+ from PIL.Image import Image
19
+ from mmocr.apis.inferencers import MMOCRInferencer
20
+
21
+ import warnings
22
+
23
+ warnings.filterwarnings("ignore")
24
+
25
+ def save_image(img, img_path):
26
+ # Convert PIL image to OpenCV image
27
+ img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
28
+ # Save OpenCV image
29
+ cv2.imwrite(img_path, img)
30
+
31
+
32
+ textdet_model_list = ['DBNet', 'DRRG', 'FCENet', 'PANet', 'PSENet', 'TextSnake', 'MaskRCNN']
33
+ textrec_model_list = ['ABINet', 'ASTER', 'CRNN', 'MASTER', 'NRTR', 'RobustScanner', 'SARNet', 'SATRN', 'SVTR']
34
+ textkie_model_list = ['SDMGR','SDMGR']
35
+
36
+
37
+ def ocr_inference(inputs, out_dir, det, det_weights, rec, rec_weights, kie, kie_weights, device, batch_size):
38
+ init_args, call_args = parse_args()
39
+ inputs = np.array(inputs)
40
+ img_path = "demo_text_ocr.jpg"
41
+ save_image(inputs, img_path)
42
+ if det is not None and rec is not None:
43
+ init_args['det'] = det
44
+ init_args['det_weights'] = None
45
+ init_args['rec'] = rec
46
+ init_args['rec_weights'] = None
47
+ elif det_weights is not None and rec_weights is not None:
48
+ init_args['det'] = None
49
+ init_args['det_weights'] = det_weights
50
+ init_args['rec'] = None
51
+ init_args['rec_weights'] = rec_weights
52
+ if kie is not None:
53
+ init_args['kie'] = kie
54
+ init_args['kie_weights'] = None
55
+ if kie_weights is not None:
56
+ init_args['kie'] = None
57
+ init_args['kie_weights'] = kie_weights
58
+
59
+ call_args['inputs'] = img_path
60
+ call_args['out_dir'] = out_dir
61
+ call_args['batch_size'] = int(batch_size)
62
+ call_args['show'] = False
63
+ call_args['save_pred'] = True
64
+ call_args['save_vis'] = True
65
+ init_args['device'] = device
66
+ print("init_args", init_args)
67
+ print("call_args", call_args)
68
+ ocr = MMOCRInferencer(**init_args)
69
+ ocr(**call_args)
70
+ save_vis_dir = './results/vis/'
71
+ save_pred_dir = './results/preds/'
72
+ img_out = PIL.Image.open(os.path.join(save_vis_dir, img_path))
73
+ json_out = json.load(open(os.path.join(save_pred_dir, img_path.replace('.jpg', '.json'))))
74
+ return img_out, json_out
75
+
76
+
77
+ def download_test_image():
78
+ # Images
79
+ torch.hub.download_url_to_file(
80
+ 'https://user-images.githubusercontent.com/59380685/266821429-9a897c0a-5b02-4260-a65b-3514b758f6b6.jpg',
81
+ 'demo_densetext_det.jpg')
82
+ torch.hub.download_url_to_file(
83
+ 'https://user-images.githubusercontent.com/59380685/266821432-17bb0646-a3e9-451e-9b4d-6e41ce4c3f0c.jpg',
84
+ 'demo_text_recog.jpg')
85
+ torch.hub.download_url_to_file(
86
+ 'https://user-images.githubusercontent.com/59380685/266821434-fe0d4d18-f3e2-4acf-baf5-0d2e318f0b09.jpg',
87
+ 'demo_text_ocr.jpg')
88
+ torch.hub.download_url_to_file(
89
+ 'https://user-images.githubusercontent.com/59380685/266821435-5d7af2b4-cb84-4355-91cb-37d90e91aa30.jpg',
90
+ 'demo_text_det.jpg')
91
+ torch.hub.download_url_to_file(
92
+ 'https://user-images.githubusercontent.com/59380685/266821436-4790c6c1-2da5-45c7-b837-04eeea0d7264.jpeg',
93
+ 'demo_kie.jpg')
94
+
95
+
96
+ def parse_args():
97
+ parser = ArgumentParser()
98
+ parser.add_argument(
99
+ '--inputs', type=str, help='Input image file or folder path.')
100
+ parser.add_argument(
101
+ '--out-dir',
102
+ type=str,
103
+ default='./results/',
104
+ help='Output directory of results.')
105
+ parser.add_argument(
106
+ '--det',
107
+ type=str,
108
+ default=None,
109
+ help='Pretrained text detection algorithm. It\'s the path to the '
110
+ 'config file or the model name defined in metafile.')
111
+ parser.add_argument(
112
+ '--det-weights',
113
+ type=str,
114
+ default=None,
115
+ help='Path to the custom checkpoint file of the selected det model. '
116
+ 'If it is not specified and "det" is a model name of metafile, the '
117
+ 'weights will be loaded from metafile.')
118
+ parser.add_argument(
119
+ '--rec',
120
+ type=str,
121
+ default=None,
122
+ help='Pretrained text recognition algorithm. It\'s the path to the '
123
+ 'config file or the model name defined in metafile.')
124
+ parser.add_argument(
125
+ '--rec-weights',
126
+ type=str,
127
+ default=None,
128
+ help='Path to the custom checkpoint file of the selected recog model. '
129
+ 'If it is not specified and "rec" is a model name of metafile, the '
130
+ 'weights will be loaded from metafile.')
131
+ parser.add_argument(
132
+ '--kie',
133
+ type=str,
134
+ default=None,
135
+ help='Pretrained key information extraction algorithm. It\'s the path'
136
+ 'to the config file or the model name defined in metafile.')
137
+ parser.add_argument(
138
+ '--kie-weights',
139
+ type=str,
140
+ default=None,
141
+ help='Path to the custom checkpoint file of the selected kie model. '
142
+ 'If it is not specified and "kie" is a model name of metafile, the '
143
+ 'weights will be loaded from metafile.')
144
+ parser.add_argument(
145
+ '--device',
146
+ type=str,
147
+ default=None,
148
+ help='Device used for inference. '
149
+ 'If not specified, the available device will be automatically used.')
150
+ parser.add_argument(
151
+ '--batch-size', type=int, default=1, help='Inference batch size.')
152
+ parser.add_argument(
153
+ '--show',
154
+ action='store_true',
155
+ help='Display the image in a popup window.')
156
+ parser.add_argument(
157
+ '--print-result',
158
+ action='store_true',
159
+ help='Whether to print the results.')
160
+ parser.add_argument(
161
+ '--save_pred',
162
+ action='store_true',
163
+ help='Save the inference results to out_dir.')
164
+ parser.add_argument(
165
+ '--save_vis',
166
+ action='store_true',
167
+ help='Save the visualization results to out_dir.')
168
+
169
+ call_args = vars(parser.parse_args())
170
+
171
+ init_kws = [
172
+ 'det', 'det_weights', 'rec', 'rec_weights', 'kie', 'kie_weights', 'device'
173
+ ]
174
+ init_args = {}
175
+ for init_kw in init_kws:
176
+ init_args[init_kw] = call_args.pop(init_kw)
177
+
178
+ return init_args, call_args
179
+
180
+
181
+ if __name__ == '__main__':
182
+ # Define Gradio input and output types
183
+ input_image = gr.inputs.Image(type="pil", label="Input Image")
184
+ out_dir = gr.inputs.Textbox(default="results")
185
+ det = gr.inputs.Dropdown(label="Text Detection Model", choices=[m for m in textdet_model_list], default='DBNet')
186
+ det_weights = gr.inputs.Textbox(default=None)
187
+ rec = gr.inputs.Dropdown(label="Text Recognition Model", choices=[m for m in textrec_model_list], default='CRNN')
188
+ rec_weights = gr.inputs.Textbox(default=None)
189
+ kie = gr.inputs.Dropdown(label="Key Information Extraction Model", choices=[m for m in textkie_model_list],
190
+ default='SDMGR')
191
+ kie_weights = gr.inputs.Textbox(default=None)
192
+ device = gr.inputs.Radio(choices=["cpu", "cuda"], label="Device used for inference", default="cpu")
193
+ batch_size = gr.inputs.Number(default=1, label="Inference batch size")
194
+ output_image = gr.outputs.Image(type="pil", label="Output Image")
195
+ output_json = gr.outputs.Textbox()
196
+ download_test_image()
197
+ examples = [["demo_text_ocr.jpg", "results", "DBNet", None, "CRNN", None, "SDMGR", None, "cpu", 1],
198
+ ["demo_text_det.jpg", "results", "FCENet", None, "ASTER", None, "SDMGR", None, "cpu", 1],
199
+ ["demo_text_recog.jpg", "results", "PANet", None, "MASTER", None, "SDMGR", None, "cpu", 1],
200
+ ["demo_densetext_det.jpg", "results", "PSENet", None, "CRNN", None, "SDMGR", None, "cpu", 1],
201
+ ["demo_kie.jpg", "results", "TextSnake", None, "RobustScanner", None, "SDMGR", None, "cpu", 1]
202
+ ]
203
+
204
+ title = "MMOCR web demo"
205
+ description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmocr/main/resources/mmocr-logo.png' width='450''/><div>" \
206
+ "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmocr'>MMOCR</a> MMOCR 是基于 PyTorch 和 mmdetection 的开源工具箱,专注于文本检测,文本识别以及相应的下游任务,如关键信息提取。 它是 OpenMMLab 项目的一部分。" \
207
+ "OpenMMLab Text Detection, Recognition and Understanding Toolbox.</p>"
208
+ article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmocr'>MMOCR</a></p>" \
209
+ "<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></a></p>"
210
+
211
+ # Create Gradio interface
212
+ iface = gr.Interface(
213
+ fn=ocr_inference,
214
+ inputs=[
215
+ input_image, out_dir, det, det_weights, rec, rec_weights,
216
+ kie, kie_weights, device, batch_size
217
+ ],
218
+ outputs=[output_image, output_json], examples=examples,
219
+ title=title, description=description, article=article,
220
+ )
221
+
222
+ # Launch Gradio interface
223
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wget~=3.2
2
+ opencv-python~=4.6.0.66
3
+ numpy~=1.23.0
4
+ torch~=1.13.1
5
+ torchvision~=0.14.1
6
+ pillow~=9.4.0
7
+ gradio~=3.42.0
8
+ ultralytics~=8.0.169
9
+ pyyaml~=6.0
10
+ wandb~=0.13.11
11
+ tqdm~=4.65.0
12
+ matplotlib~=3.7.1
13
+ pandas~=2.0.0
14
+ seaborn~=0.12.2
15
+ requests~=2.31.0
16
+ psutil~=5.9.4
17
+ thop~=0.1.1-2209072238
18
+ timm~=0.9.2
19
+ super-gradients~=3.2.0
20
+ openmim