isLinXu committed on
Commit
c553e79
·
1 Parent(s): 9223de5
Files changed (1) hide show
  1. app.py +185 -55
app.py CHANGED
@@ -1,28 +1,27 @@
1
-
2
-
3
  import os
4
- os.system("pip install xtcocotools>=1.12")
 
5
  os.system("pip install 'mmengine>=0.6.0'")
6
  os.system("pip install 'mmcv>=2.0.0rc4,<2.1.0'")
7
  os.system("pip install 'mmdet>=3.0.0,<4.0.0'")
8
- os.system("pip install 'mmpose'")
 
 
 
 
9
 
10
  import PIL
11
  import cv2
12
- import mmpose
13
  import numpy as np
14
-
15
  import torch
16
- from mmpose.apis import MMPoseInferencer
17
- import gradio as gr
18
 
19
  import warnings
20
 
21
  warnings.filterwarnings("ignore")
22
 
23
# Model aliases offered to the user; the selected entry is handed unchanged
# to MMPoseInferencer as the model name.
mmpose_model_list = [
    "human",
    "hand",
    "face",
    "animal",
    "wholebody",
    "vitpose",
    "vitpose-s",
    "vitpose-b",
    "vitpose-l",
    "vitpose-h",
]
25
-
26
 
27
  def save_image(img, img_path):
28
  # Convert PIL image to OpenCV image
@@ -31,52 +30,183 @@ def save_image(img, img_path):
31
  cv2.imwrite(img_path, img)
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
def download_test_image():
    """Download the three sample pictures into the current working directory."""
    # (source URL, local file name) pairs, fetched in this order.
    samples = (
        ('https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
         'bus.jpg'),
        ('https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
         'dogs.jpg'),
        ('https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
         'zidane.jpg'),
    )
    for url, filename in samples:
        torch.hub.download_url_to_file(url, filename)
45
-
46
-
47
def predict_pose(img, model_name, out_dir):
    """Run MMPose on ``img`` and return the rendered result as a PIL image.

    Args:
        img: Input picture; it is written to ``input_img.jpg`` via
            ``save_image`` before inference.
        model_name: Model alias or name passed to ``MMPoseInferencer``.
        out_dir: Directory MMPose is asked to write its outputs into.

    Returns:
        The visualization found under ``<out_dir>/visualizations/`` when it
        exists, otherwise the saved input picture.
    """
    img_path = "input_img.jpg"
    save_image(img, img_path)
    device = torch.cuda.current_device() if torch.cuda.is_available() else 'cpu'
    inferencer = MMPoseInferencer(model_name, device=device)
    # The call returns a generator; advancing it once processes the single
    # input picture. The returned prediction itself is not needed here.
    next(inferencer(img_path, show=False, out_dir=out_dir))

    # Look for the visualization inside the directory the caller asked for
    # (previously hard-coded to './output/visualizations/'), and check the
    # file itself so a stale or empty directory cannot cause a failed open.
    vis_path = os.path.join(out_dir, 'visualizations', img_path) if out_dir else None
    if vis_path and os.path.isfile(vis_path):
        out_img_path = vis_path
        print("out_img_path: ", out_img_path)
    else:
        out_img_path = img_path
    return PIL.Image.open(out_img_path)
62
-
63
# Fetch the sample pictures referenced by ``examples`` before building the UI.
download_test_image()

# Widgets; the three inputs map positionally onto predict_pose's parameters.
input_image = gr.inputs.Image(type='pil', label="Original Image")
model_name = gr.inputs.Dropdown(choices=list(mmpose_model_list), label='Model')
out_dir = gr.inputs.Textbox(label="Output Directory", default="./output")
output_image = gr.outputs.Image(type="pil", label="Output Image")

# Each example row supplies the picture and the model to preselect.
examples = [['zidane.jpg', 'human'], ['dogs.jpg', 'animal']]

# Static page texts (HTML) shown around the demo.
title = "MMPose detection web demo"
description = (
    "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmpose/main/resources/mmpose-logo.png' width='450''/><div>"
    "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmpose'>MMPose</a> MMPose 是一款基于 PyTorch 的姿态分析的开源工具箱,是 OpenMMLab 项目的成员之一。"
    "OpenMMLab Pose Estimation Toolbox and Benchmark..</p>"
)
article = (
    "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmpose'>MMPose</a></p>"
    "<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></a></p>"
)

# Assemble the interface and start serving it.
iface = gr.Interface(
    fn=predict_pose,
    inputs=[input_image, model_name, out_dir],
    outputs=output_image,
    examples=examples,
    title=title,
    description=description,
    article=article,
)
iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+
3
+ os.system("pip install gradio==3.42.0")
4
  os.system("pip install 'mmengine>=0.6.0'")
5
  os.system("pip install 'mmcv>=2.0.0rc4,<2.1.0'")
6
  os.system("pip install 'mmdet>=3.0.0,<4.0.0'")
7
+ os.system("pip install mmocr")
8
+
9
+ import json
10
+ import os
11
+ from argparse import ArgumentParser
12
 
13
  import PIL
14
  import cv2
15
+ import gradio as gr
16
  import numpy as np
 
17
  import torch
18
+ from PIL.Image import Image
19
+ from mmocr.apis.inferencers import MMOCRInferencer
20
 
21
  import warnings
22
 
23
  warnings.filterwarnings("ignore")
24
 
 
 
 
25
 
26
  def save_image(img, img_path):
27
  # Convert PIL image to OpenCV image
 
30
  cv2.imwrite(img_path, img)
31
 
32
 
33
# Model names offered for text detection.
textdet_model_list = [
    'DBNet',
    'DRRG',
    'FCENet',
    'PANet',
    'PSENet',
    'TextSnake',
    'MaskRCNN',
]
# Model names offered for text recognition.
textrec_model_list = [
    'ABINet',
    'ASTER',
    'CRNN',
    'MASTER',
    'NRTR',
    'RobustScanner',
    'SARNet',
    'SATRN',
    'SVTR',
]
# Model names for key information extraction.
textkie_model_list = ['SDMGR']
36
+
37
def ocr_inference(inputs, out_dir, det, det_weights, rec, rec_weights, kie, kie_weights, device, batch_size):
    """Run MMOCR on one picture and return its visualization and predictions.

    Args:
        inputs: Input picture; converted with ``np.array`` and written to
            disk through ``save_image``.
        out_dir: Directory MMOCR writes results into. When blank, the
            default from ``parse_args`` is used.
        det: Text-detection model name, or ``None``/blank.
        det_weights: Detection checkpoint path. Only used when ``det`` and
            ``rec`` are not both given.
        rec: Text-recognition model name, or ``None``/blank.
        rec_weights: Recognition checkpoint path. Only used when ``det`` and
            ``rec`` are not both given.
        kie: Optional key-information-extraction model name.
        kie_weights: Optional checkpoint path for ``kie``.
        device: Device string passed to ``MMOCRInferencer`` (``None`` lets
            it choose).
        batch_size: Inference batch size; converted with ``int``.

    Returns:
        tuple: ``(visualization as PIL image, predictions parsed from JSON)``.
    """

    def _blank_to_none(value):
        # Text widgets may deliver '' for "not set"; treat that like None so
        # an empty string is never forwarded as a model name or path.
        if isinstance(value, str) and not value.strip():
            return None
        return value

    det, det_weights, rec, rec_weights, kie, kie_weights, device = map(
        _blank_to_none,
        (det, det_weights, rec, rec_weights, kie, kie_weights, device))

    init_args, call_args = parse_args()
    out_dir = _blank_to_none(out_dir) or call_args['out_dir']

    # Dedicated scratch name: the previous name, "demo_text_ocr.jpg", is also
    # one of the downloaded example pictures and was overwritten by every run.
    img_path = "ocr_input.jpg"
    save_image(np.array(inputs), img_path)

    # Model names take precedence; checkpoints are only used as a pair when
    # the names are not both available.
    if det is not None and rec is not None:
        init_args.update(det=det, det_weights=None, rec=rec, rec_weights=None)
    elif det_weights is not None and rec_weights is not None:
        init_args.update(det=None, det_weights=det_weights,
                         rec=None, rec_weights=rec_weights)
    # Previously accepted but silently dropped.
    if kie is not None:
        init_args.update(kie=kie, kie_weights=kie_weights)
    init_args['device'] = device

    call_args.update(
        inputs=img_path,
        out_dir=out_dir,
        batch_size=int(batch_size or 1),
        show=False,
        save_pred=True,
        save_vis=True,
    )
    print("init_args", init_args)
    print("call_args", call_args)

    ocr = MMOCRInferencer(**init_args)
    ocr(**call_args)

    # Read the results back from the directory that was actually requested
    # (previously hard-coded to '../../results/...'). MMOCR is expected to
    # place them in '<out_dir>/vis' and '<out_dir>/preds', named after the
    # input file's stem.
    stem = os.path.splitext(os.path.basename(img_path))[0]
    img_out = PIL.Image.open(os.path.join(out_dir, 'vis', stem + '.jpg'))
    with open(os.path.join(out_dir, 'preds', stem + '.json'), encoding='utf-8') as pred_file:
        json_out = json.load(pred_file)
    return img_out, json_out
68
+
69
+
70
def download_test_image():
    """Download the five demo pictures into the current working directory."""
    # (source URL, local file name) pairs, fetched in this order.
    samples = (
        ('https://user-images.githubusercontent.com/59380685/266821429-9a897c0a-5b02-4260-a65b-3514b758f6b6.jpg',
         'demo_densetext_det.jpg'),
        ('https://user-images.githubusercontent.com/59380685/266821432-17bb0646-a3e9-451e-9b4d-6e41ce4c3f0c.jpg',
         'demo_text_recog.jpg'),
        ('https://user-images.githubusercontent.com/59380685/266821434-fe0d4d18-f3e2-4acf-baf5-0d2e318f0b09.jpg',
         'demo_text_ocr.jpg'),
        ('https://user-images.githubusercontent.com/59380685/266821435-5d7af2b4-cb84-4355-91cb-37d90e91aa30.jpg',
         'demo_text_det.jpg'),
        ('https://user-images.githubusercontent.com/59380685/266821436-4790c6c1-2da5-45c7-b837-04eeea0d7264.jpeg',
         'demo_kie.jpg'),
    )
    for url, filename in samples:
        torch.hub.download_url_to_file(url, filename)
87
+
88
+
89
def parse_args(argv=None):
    """Parse command-line options and split them for the OCR inferencer.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the process command line is parsed, as before.

    Returns:
        tuple: ``(init_args, call_args)``. ``init_args`` holds the keys
        ``det``, ``det_weights``, ``rec``, ``rec_weights``, ``kie``,
        ``kie_weights`` and ``device``; ``call_args`` holds every remaining
        option (``inputs``, ``out_dir``, ``batch_size``, ``show``,
        ``print_result``, ``save_pred``, ``save_vis``).
    """
    parser = ArgumentParser()
    parser.add_argument(
        '--inputs', type=str, help='Input image file or folder path.')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='./results/',
        help='Output directory of results.')
    parser.add_argument(
        '--det',
        type=str,
        default=None,
        help='Pretrained text detection algorithm. It\'s the path to the '
        'config file or the model name defined in metafile.')
    parser.add_argument(
        '--det-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected det model. '
        'If it is not specified and "det" is a model name of metafile, the '
        'weights will be loaded from metafile.')
    parser.add_argument(
        '--rec',
        type=str,
        default=None,
        help='Pretrained text recognition algorithm. It\'s the path to the '
        'config file or the model name defined in metafile.')
    parser.add_argument(
        '--rec-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected recog model. '
        'If it is not specified and "rec" is a model name of metafile, the '
        'weights will be loaded from metafile.')
    parser.add_argument(
        '--kie',
        type=str,
        default=None,
        # The first fragment now ends with a space; it used to render "pathto".
        help='Pretrained key information extraction algorithm. It\'s the path '
        'to the config file or the model name defined in metafile.')
    parser.add_argument(
        '--kie-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected kie model. '
        'If it is not specified and "kie" is a model name of metafile, the '
        'weights will be loaded from metafile.')
    parser.add_argument(
        '--device',
        type=str,
        default=None,
        help='Device used for inference. '
        'If not specified, the available device will be automatically used.')
    parser.add_argument(
        '--batch-size', type=int, default=1, help='Inference batch size.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the image in a popup window.')
    parser.add_argument(
        '--print-result',
        action='store_true',
        help='Whether to print the results.')
    parser.add_argument(
        '--save_pred',
        action='store_true',
        help='Save the inference results to out_dir.')
    parser.add_argument(
        '--save_vis',
        action='store_true',
        help='Save the visualization results to out_dir.')

    call_args = vars(parser.parse_args(argv))

    # Options consumed when constructing the inferencer; everything left in
    # call_args afterwards is meant for the inference call itself.
    init_kws = [
        'det', 'det_weights', 'rec', 'rec_weights', 'kie', 'kie_weights', 'device'
    ]
    init_args = {init_kw: call_args.pop(init_kw) for init_kw in init_kws}

    return init_args, call_args
172
+
173
+
174
if __name__ == '__main__':
    # Define Gradio input and output types
    input_image = gr.inputs.Image(type="pil", label="Input Image")
    out_dir = gr.inputs.Textbox(default="results")
    det = gr.inputs.Dropdown(label="Text Detection Model", choices=list(textdet_model_list), default='DBNet')
    det_weights = gr.inputs.Textbox(default=None)
    rec = gr.inputs.Dropdown(label="Text Recognition Model", choices=list(textrec_model_list), default='CRNN')
    rec_weights = gr.inputs.Textbox(default=None)
    device = gr.inputs.Radio(choices=["cpu", "cuda"], label="Device used for inference", default="cpu")
    batch_size = gr.inputs.Number(default=1, label="Inference batch size")
    output_image = gr.outputs.Image(type="pil", label="Output Image")
    output_json = gr.outputs.Textbox()
    download_test_image()

    # One row per example, one column per input widget, in this order:
    # image, out_dir, det, det_weights, rec, rec_weights, device, batch_size.
    # The first three rows used to lack the rec_weights column.
    examples = [
        ["demo_text_ocr.jpg", "results", "DBNet", None, "CRNN", None, "cpu", 1],
        ["demo_text_det.jpg", "results", "FCENet", None, "ASTER", None, "cpu", 1],
        ["demo_text_recog.jpg", "results", "PANet", None, "MASTER", None, "cpu", 1],
        ["demo_densetext_det.jpg", "results", "PSENet", None, "CRNN", None, "cpu", 1],
        ["demo_kie.jpg", "results", "TextSnake", None, "RobustScanner", None, "cpu", 1],
    ]

    title = "MMOCR web demo"
    # Markup repaired: stray quote after width='450' removed, the logo <div>
    # is now closed, and the duplicated </a> in the article is gone.
    description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmocr/main/resources/mmocr-logo.png' width='450'/></div>" \
                  "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmocr'>MMOCR</a> MMOCR 是基于 PyTorch 和 mmdetection 的开源工具箱,专注于文本检测,文本识别以及相应的下游任务,如关键信息提取。 它是 OpenMMLab 项目的一部分。" \
                  "OpenMMLab Text Detection, Recognition and Understanding Toolbox.</p>"
    article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmocr'>MMOCR</a></p>" \
              "<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></p>"

    def _run_ocr(image, output_dir, det_name, det_ckpt, rec_name, rec_ckpt, device_name, batch):
        """Adapt the eight UI inputs to ocr_inference's ten parameters.

        The interface has no KIE widgets, so kie and kie_weights are passed
        as None. Wiring ocr_inference in directly shifted device and
        batch_size into the kie slots and left two parameters unfilled.
        """
        return ocr_inference(image, output_dir, det_name, det_ckpt,
                             rec_name, rec_ckpt, None, None,
                             device_name, batch)

    # Create Gradio interface
    iface = gr.Interface(
        fn=_run_ocr,
        inputs=[
            input_image, out_dir, det, det_weights, rec, rec_weights, device, batch_size
        ],
        outputs=[output_image, output_json], examples=examples,
        title=title, description=description, article=article,
    )

    # Launch Gradio interface
    iface.launch()