isLinXu committed
Commit 9223de5 · 1 Parent(s): 6c89ef7
Files changed (1)
  1. app.py +55 -190
app.py CHANGED
@@ -1,27 +1,29 @@

  import os
- os.system("pip install gradio==3.42.0")
  os.system("pip install 'mmengine>=0.6.0'")
  os.system("pip install 'mmcv>=2.0.0rc4,<2.1.0'")
  os.system("pip install 'mmdet>=3.0.0,<4.0.0'")
- os.system("pip install mmocr")
-
- import json
- import os
- from argparse import ArgumentParser

  import PIL
  import cv2
- import gradio as gr
  import numpy as np
  import torch
- from PIL.Image import Image
- from mmocr.apis.inferencers import MMOCRInferencer

  import warnings

  warnings.filterwarnings("ignore")


  def save_image(img, img_path):
      # Convert PIL image to OpenCV image
      img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
@@ -29,189 +31,52 @@ def save_image(img, img_path):
      cv2.imwrite(img_path, img)


- textdet_model_list = ['DBNet', 'DRRG', 'FCENet', 'PANet', 'PSENet', 'TextSnake', 'MaskRCNN']
- textrec_model_list = ['ABINet', 'ASTER', 'CRNN', 'MASTER', 'NRTR', 'RobustScanner', 'SARNet', 'SATRN', 'SVTR']
- textkie_model_list = ['SDMGR', 'SDMGR']
-
-
- def ocr_inference(inputs, out_dir, det, det_weights, rec, rec_weights, device):
-     init_args, call_args = parse_args()
-     inputs = np.array(inputs)
-     img_path = "demo_text_ocr.jpg"
-     save_image(inputs, img_path)
-     if det is not None and rec is not None:
-         init_args['det'] = det
-         init_args['det_weights'] = None
-         init_args['rec'] = rec
-         init_args['rec_weights'] = None
-     elif det_weights is not None and rec_weights is not None:
-         init_args['det'] = None
-         init_args['det_weights'] = det_weights
-         init_args['rec'] = None
-         init_args['rec_weights'] = rec_weights
-
-     call_args['inputs'] = img_path
-     call_args['out_dir'] = out_dir
-     call_args['batch_size'] = 1
-     call_args['show'] = False
-     call_args['save_pred'] = True
-     call_args['save_vis'] = True
-     init_args['device'] = device
-     print("init_args", init_args)
-     print("call_args", call_args)
-     ocr = MMOCRInferencer(**init_args)
-     ocr(**call_args)
-     save_vis_dir = './results/vis/'
-     save_pred_dir = './results/preds/'
-     img_out = PIL.Image.open(os.path.join(save_vis_dir, img_path))
-     json_out = json.load(open(os.path.join(save_pred_dir, img_path.replace('.jpg', '.json'))))
-     return img_out, json_out
-
-
  def download_test_image():
      # Images
      torch.hub.download_url_to_file(
-         'https://user-images.githubusercontent.com/59380685/266821429-9a897c0a-5b02-4260-a65b-3514b758f6b6.jpg',
-         'demo_densetext_det.jpg')
-     torch.hub.download_url_to_file(
-         'https://user-images.githubusercontent.com/59380685/266821432-17bb0646-a3e9-451e-9b4d-6e41ce4c3f0c.jpg',
-         'demo_text_recog.jpg')
-     torch.hub.download_url_to_file(
-         'https://user-images.githubusercontent.com/59380685/266821434-fe0d4d18-f3e2-4acf-baf5-0d2e318f0b09.jpg',
-         'demo_text_ocr.jpg')
      torch.hub.download_url_to_file(
-         'https://user-images.githubusercontent.com/59380685/266821435-5d7af2b4-cb84-4355-91cb-37d90e91aa30.jpg',
-         'demo_text_det.jpg')
      torch.hub.download_url_to_file(
-         'https://user-images.githubusercontent.com/59380685/266821436-4790c6c1-2da5-45c7-b837-04eeea0d7264.jpeg',
-         'demo_kie.jpg')
-
-
- def parse_args():
-     parser = ArgumentParser()
-     parser.add_argument(
-         '--inputs', type=str, help='Input image file or folder path.')
-     parser.add_argument(
-         '--out-dir',
-         type=str,
-         default='./results/',
-         help='Output directory of results.')
-     parser.add_argument(
-         '--det',
-         type=str,
-         default=None,
-         help='Pretrained text detection algorithm. It\'s the path to the '
-         'config file or the model name defined in metafile.')
-     parser.add_argument(
-         '--det-weights',
-         type=str,
-         default=None,
-         help='Path to the custom checkpoint file of the selected det model. '
-         'If it is not specified and "det" is a model name of metafile, the '
-         'weights will be loaded from metafile.')
-     parser.add_argument(
-         '--rec',
-         type=str,
-         default=None,
-         help='Pretrained text recognition algorithm. It\'s the path to the '
-         'config file or the model name defined in metafile.')
-     parser.add_argument(
-         '--rec-weights',
-         type=str,
-         default=None,
-         help='Path to the custom checkpoint file of the selected recog model. '
-         'If it is not specified and "rec" is a model name of metafile, the '
-         'weights will be loaded from metafile.')
-     parser.add_argument(
-         '--kie',
-         type=str,
-         default=None,
-         help='Pretrained key information extraction algorithm. It\'s the path '
-         'to the config file or the model name defined in metafile.')
-     parser.add_argument(
-         '--kie-weights',
-         type=str,
-         default=None,
-         help='Path to the custom checkpoint file of the selected kie model. '
-         'If it is not specified and "kie" is a model name of metafile, the '
-         'weights will be loaded from metafile.')
-     parser.add_argument(
-         '--device',
-         type=str,
-         default=None,
-         help='Device used for inference. '
-         'If not specified, the available device will be automatically used.')
-     parser.add_argument(
-         '--batch-size', type=int, default=1, help='Inference batch size.')
-     parser.add_argument(
-         '--show',
-         action='store_true',
-         help='Display the image in a popup window.')
-     parser.add_argument(
-         '--print-result',
-         action='store_true',
-         help='Whether to print the results.')
-     parser.add_argument(
-         '--save_pred',
-         action='store_true',
-         help='Save the inference results to out_dir.')
-     parser.add_argument(
-         '--save_vis',
-         action='store_true',
-         help='Save the visualization results to out_dir.')
-
-     call_args = vars(parser.parse_args())
-
-     init_kws = [
-         'det', 'det_weights', 'rec', 'rec_weights', 'kie', 'kie_weights', 'device'
-     ]
-     init_args = {}
-     for init_kw in init_kws:
-         init_args[init_kw] = call_args.pop(init_kw)
-
-     return init_args, call_args
-
-
- if __name__ == '__main__':
-     # Define Gradio input and output types
-     input_image = gr.inputs.Image(type="pil", label="Input Image")
-     out_dir = gr.inputs.Textbox(default="results")
-     det = gr.inputs.Dropdown(label="Text Detection Model", choices=[m for m in textdet_model_list], default='DBNet')
-     det_weights = gr.inputs.Textbox(default=None)
-     rec = gr.inputs.Dropdown(label="Text Recognition Model", choices=[m for m in textrec_model_list], default='CRNN')
-     rec_weights = gr.inputs.Textbox(default=None)
-     kie = gr.inputs.Textbox(default='SDMGR')
-     # kie = gr.inputs.Dropdown(label="Key Information Extraction Model", choices=[m for m in textkie_model_list],
-     #                          default='SDMGR')
-     # kie_weights = gr.inputs.Textbox(default=None)
-     device = gr.inputs.Radio(choices=["cpu", "cuda"], label="Device used for inference", default="cpu")
-     batch_size = gr.inputs.Number(default=1, label="Inference batch size")
-     output_image = gr.outputs.Image(type="pil", label="Output Image")
-     output_json = gr.outputs.Textbox()
-     download_test_image()
-     examples = [["demo_text_ocr.jpg", "results", "DBNet", None, "CRNN", None, "cpu"],
-                 ["demo_text_det.jpg", "results", "FCENet", None, "ASTER", None, "cpu"],
-                 ["demo_text_recog.jpg", "results", "DBNet", None, "MASTER", None, "cpu"],
-                 ["demo_densetext_det.jpg", "results", "PSENet", None, "CRNN", None, "cpu"],
-                 ["demo_kie.jpg", "results", "TextSnake", None, "RobustScanner", None, "cpu"]
-                 ]
-
-     title = "MMOCR web demo"
-     description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmocr/main/resources/mmocr-logo.png' width='450'/><div>" \
-                   "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmocr'>MMOCR</a> is an open-source toolbox based on PyTorch and mmdetection, focusing on text detection, text recognition, and downstream tasks such as key information extraction. It is part of the OpenMMLab project. " \
-                   "OpenMMLab Text Detection, Recognition and Understanding Toolbox.</p>"
-     article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmocr'>MMOCR</a></p>" \
-               "<p style='text-align: center'><a href='https://github.com/isLinXu'>Gradio demo built by gatilin</a></p>"
-
-     # Create Gradio interface
-     iface = gr.Interface(
-         fn=ocr_inference,
-         inputs=[
-             input_image, out_dir, det, det_weights, rec, rec_weights, device, batch_size
-         ],
-         outputs=[output_image, output_json], examples=examples,
-         title=title, description=description, article=article,
-     )
-
-     # Launch Gradio interface
-     iface.launch()
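
Aside: the removed ocr_inference()/parse_args() pair is a thin wrapper around MMOCR's MMOCRInferencer; the init/call arguments printed above map one-to-one onto it. A minimal sketch of the equivalent direct call, assuming mmocr 1.x is installed and one of the demo images downloaded by download_test_image() is present:

    from mmocr.apis.inferencers import MMOCRInferencer

    # Same model names as in the dropdown lists above.
    ocr = MMOCRInferencer(det='DBNet', rec='CRNN', device='cpu')
    # Writes visualizations under ./results/vis/ and JSON predictions under ./results/preds/.
    ocr('demo_text_ocr.jpg', out_dir='./results/', batch_size=1,
        show=False, save_vis=True, save_pred=True)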
 
  import os
+ os.system("pip install xtcocotools>=1.12")
  os.system("pip install 'mmengine>=0.6.0'")
  os.system("pip install 'mmcv>=2.0.0rc4,<2.1.0'")
  os.system("pip install 'mmdet>=3.0.0,<4.0.0'")
+ os.system("pip install 'mmpose'")

  import PIL
  import cv2
+ import mmpose
  import numpy as np
+
  import torch
+ from mmpose.apis import MMPoseInferencer
+ import gradio as gr

  import warnings

  warnings.filterwarnings("ignore")

+ mmpose_model_list = ["human", "hand", "face", "animal", "wholebody",
+                      "vitpose", "vitpose-s", "vitpose-b", "vitpose-l", "vitpose-h"]
+
+
  def save_image(img, img_path):
      # Convert PIL image to OpenCV image
      img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
      cv2.imwrite(img_path, img)


  def download_test_image():
      # Images
      torch.hub.download_url_to_file(
+         'https://user-images.githubusercontent.com/59380685/266264420-21575a83-4057-41cf-8a4a-b3ea6f332d79.jpg',
+         'bus.jpg')
      torch.hub.download_url_to_file(
+         'https://user-images.githubusercontent.com/59380685/266264536-82afdf58-6b9a-4568-b9df-551ee72cb6d9.jpg',
+         'dogs.jpg')
      torch.hub.download_url_to_file(
+         'https://user-images.githubusercontent.com/59380685/266264600-9d0c26ca-8ba6-45f2-b53b-4dc98460c43e.jpg',
+         'zidane.jpg')
+
+
+ def predict_pose(img, model_name, out_dir):
+     img_path = "input_img.jpg"
+     save_image(img, img_path)
+     # Use an explicit device string; torch.cuda.current_device() returns an int index.
+     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+     inferencer = MMPoseInferencer(model_name, device=device)
+     # The inferencer yields one result per input; out_dir receives the rendered image.
+     result_generator = inferencer(img_path, show=False, out_dir=out_dir)
+     result = next(result_generator)
+     save_dir = './output/visualizations/'
+     if os.path.exists(save_dir):
+         out_img_path = save_dir + img_path
+         print("out_img_path: ", out_img_path)
+     else:
+         out_img_path = img_path
+     out_img = PIL.Image.open(out_img_path)
+     return out_img
+
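
Aside: stripped of the Gradio plumbing, predict_pose() reduces to the standard MMPoseInferencer pattern. A minimal sketch, assuming mmpose 1.x is installed and zidane.jpg has been fetched by download_test_image():

    from mmpose.apis import MMPoseInferencer

    inferencer = MMPoseInferencer('human', device='cpu')
    # Calling the inferencer returns a generator; with out_dir set it also writes
    # a rendered image under ./output/visualizations/.
    result = next(inferencer('zidane.jpg', show=False, out_dir='./output'))
    # result['predictions'] holds per-image lists of instance dicts with
    # 'keypoints' and 'keypoint_scores' entries.
    print(len(result['predictions'][0]), 'instances detected')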
+ download_test_image()
+ input_image = gr.inputs.Image(type='pil', label="Original Image")
+ model_name = gr.inputs.Dropdown(choices=[m for m in mmpose_model_list], label='Model')
+ out_dir = gr.inputs.Textbox(label="Output Directory", default="./output")
+ output_image = gr.outputs.Image(type="pil", label="Output Image")
+
+ # Each example row must supply a value for every input component (image, model, out_dir).
+ examples = [
+     ['zidane.jpg', 'human', './output'],
+     ['dogs.jpg', 'animal', './output'],
+ ]
+ title = "MMPose web demo"
+ description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmpose/main/resources/mmpose-logo.png' width='450'/><div>" \
+               "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmpose'>MMPose</a> is an open-source pose-analysis toolbox based on PyTorch and is part of the OpenMMLab project. " \
+               "OpenMMLab Pose Estimation Toolbox and Benchmark.</p>"
+ article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmpose'>MMPose</a></p>" \
+           "<p style='text-align: center'><a href='https://github.com/isLinXu'>Gradio demo built by gatilin</a></p>"
+
+ iface = gr.Interface(fn=predict_pose, inputs=[input_image, model_name, out_dir], outputs=output_image,
+                      examples=examples, title=title, description=description, article=article)
+ iface.launch()
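
One caveat worth flagging: the commit removes the old gradio==3.42.0 pin, yet the new code still uses the Gradio 3.x gr.inputs/gr.outputs namespaces, which were dropped in Gradio 4. If the Space ever resolves a newer Gradio, the components need the 4.x spellings; a hedged sketch (untested against this Space, reusing mmpose_model_list from app.py):

    import gradio as gr

    # Gradio 4.x equivalents of the deprecated gr.inputs/gr.outputs components.
    input_image = gr.Image(type='pil', label="Original Image")
    model_name = gr.Dropdown(choices=mmpose_model_list, label='Model')
    out_dir = gr.Textbox(label="Output Directory", value="./output")  # 'default' became 'value'
    output_image = gr.Image(type="pil", label="Output Image")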