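"""Gradio web demo for MMOCR text detection and recognition.

The script installs its runtime dependencies on start-up, downloads a few
sample images, and exposes MMOCRInferencer through a simple Gradio interface.
"""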
import os
os.system("pip install gradio==3.42.0")
os.system("pip install 'mmengine>=0.6.0'")
os.system("pip install 'mmcv>=2.0.0rc4,<2.1.0'")
os.system("pip install 'mmdet>=3.0.0rc5, < 3.2.0'")
os.system("pip install mmocr")
import json
from argparse import ArgumentParser
import PIL
import cv2
import gradio as gr
import numpy as np
import torch
import PIL.Image  # PIL.Image.open is used below to load the rendered visualization
from mmocr.apis.inferencers import MMOCRInferencer
import warnings
warnings.filterwarnings("ignore")
def save_image(img, img_path):
# Convert PIL image to OpenCV image
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
# Save OpenCV image
cv2.imwrite(img_path, img)
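# Model names supported by this demo; MMOCRInferencer resolves each name to a
# config/checkpoint pair via the MMOCR metafiles.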
textdet_model_list = ['DBNet', 'DRRG', 'FCENet', 'PANet', 'PSENet', 'TextSnake', 'MaskRCNN']
textrec_model_list = ['ABINet', 'ASTER', 'CRNN', 'MASTER', 'NRTR', 'RobustScanner', 'SARNet', 'SATRN', 'SVTR']
textkie_model_list = ['SDMGR']
def ocr_inference(inputs, out_dir, det, det_weights, rec, rec_weights, device, kie=None, kie_weights=None):
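    """Run text detection and recognition on a single image.

    Returns the rendered visualization and the raw JSON predictions. The Gradio
    interface passes the first seven arguments positionally; the KIE arguments
    are optional and currently unused.
    """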
init_args, call_args = parse_args()
inputs = np.array(inputs)
img_path = "demo_text_ocr.jpg"
save_image(inputs, img_path)
if det is not None and rec is not None:
init_args['det'] = det
init_args['det_weights'] = None
init_args['rec'] = rec
init_args['rec_weights'] = None
elif det_weights is not None and rec_weights is not None:
init_args['det'] = None
init_args['det_weights'] = det_weights
init_args['rec'] = None
init_args['rec_weights'] = rec_weights
call_args['inputs'] = img_path
call_args['out_dir'] = out_dir
call_args['batch_size'] = 1
call_args['show'] = False
call_args['save_pred'] = True
call_args['save_vis'] = True
init_args['device'] = device
print("init_args", init_args)
print("call_args", call_args)
ocr = MMOCRInferencer(**init_args)
ocr(**call_args)
    # MMOCRInferencer writes visualizations to <out_dir>/vis/ and predictions to <out_dir>/preds/
    save_vis_dir = os.path.join(out_dir, 'vis')
    save_pred_dir = os.path.join(out_dir, 'preds')
    img_out = PIL.Image.open(os.path.join(save_vis_dir, img_path))
    with open(os.path.join(save_pred_dir, img_path.replace('.jpg', '.json'))) as f:
        json_out = json.load(f)
return img_out, json_out
def download_test_image():
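    """Download sample images for the demo and the Gradio examples."""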
# Images
torch.hub.download_url_to_file(
'https://user-images.githubusercontent.com/59380685/266821429-9a897c0a-5b02-4260-a65b-3514b758f6b6.jpg',
'demo_densetext_det.jpg')
torch.hub.download_url_to_file(
'https://user-images.githubusercontent.com/59380685/266821432-17bb0646-a3e9-451e-9b4d-6e41ce4c3f0c.jpg',
'demo_text_recog.jpg')
torch.hub.download_url_to_file(
'https://user-images.githubusercontent.com/59380685/266821434-fe0d4d18-f3e2-4acf-baf5-0d2e318f0b09.jpg',
'demo_text_ocr.jpg')
torch.hub.download_url_to_file(
'https://user-images.githubusercontent.com/59380685/266821435-5d7af2b4-cb84-4355-91cb-37d90e91aa30.jpg',
'demo_text_det.jpg')
torch.hub.download_url_to_file(
'https://user-images.githubusercontent.com/59380685/266821436-4790c6c1-2da5-45c7-b837-04eeea0d7264.jpeg',
'demo_kie.jpg')
def parse_args():
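    """Parse CLI options and split them into MMOCRInferencer init and call kwargs."""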
parser = ArgumentParser()
parser.add_argument(
'--inputs', type=str, help='Input image file or folder path.')
parser.add_argument(
'--out-dir',
type=str,
default='./results/',
help='Output directory of results.')
parser.add_argument(
'--det',
type=str,
default=None,
help='Pretrained text detection algorithm. It\'s the path to the '
'config file or the model name defined in metafile.')
parser.add_argument(
'--det-weights',
type=str,
default=None,
help='Path to the custom checkpoint file of the selected det model. '
'If it is not specified and "det" is a model name of metafile, the '
'weights will be loaded from metafile.')
parser.add_argument(
'--rec',
type=str,
default=None,
help='Pretrained text recognition algorithm. It\'s the path to the '
'config file or the model name defined in metafile.')
parser.add_argument(
'--rec-weights',
type=str,
default=None,
help='Path to the custom checkpoint file of the selected recog model. '
'If it is not specified and "rec" is a model name of metafile, the '
'weights will be loaded from metafile.')
parser.add_argument(
'--kie',
type=str,
default=None,
        help='Pretrained key information extraction algorithm. It\'s the path '
        'to the config file or the model name defined in metafile.')
parser.add_argument(
'--kie-weights',
type=str,
default=None,
help='Path to the custom checkpoint file of the selected kie model. '
'If it is not specified and "kie" is a model name of metafile, the '
'weights will be loaded from metafile.')
parser.add_argument(
'--device',
type=str,
default=None,
help='Device used for inference. '
'If not specified, the available device will be automatically used.')
parser.add_argument(
'--batch-size', type=int, default=1, help='Inference batch size.')
parser.add_argument(
'--show',
action='store_true',
help='Display the image in a popup window.')
parser.add_argument(
'--print-result',
action='store_true',
help='Whether to print the results.')
parser.add_argument(
'--save_pred',
action='store_true',
help='Save the inference results to out_dir.')
parser.add_argument(
'--save_vis',
action='store_true',
help='Save the visualization results to out_dir.')
call_args = vars(parser.parse_args())
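    # Keyword arguments consumed by the MMOCRInferencer constructor; everything
    # else in call_args is forwarded to the inference call.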
init_kws = [
'det', 'det_weights', 'rec', 'rec_weights', 'kie', 'kie_weights', 'device'
]
init_args = {}
for init_kw in init_kws:
init_args[init_kw] = call_args.pop(init_kw)
return init_args, call_args
if __name__ == '__main__':
# Define Gradio input and output types
input_image = gr.inputs.Image(type="pil", label="Input Image")
out_dir = gr.inputs.Textbox(default="results")
det = gr.inputs.Dropdown(label="Text Detection Model", choices=[m for m in textdet_model_list], default='DBNet')
det_weights = gr.inputs.Textbox(default=None)
rec = gr.inputs.Dropdown(label="Text Recognition Model", choices=[m for m in textrec_model_list], default='CRNN')
rec_weights = gr.inputs.Textbox(default=None)
device = gr.inputs.Radio(choices=["cpu", "cuda"], label="Device used for inference", default="cpu")
batch_size = gr.inputs.Number(default=1, label="Inference batch size")
output_image = gr.outputs.Image(type="pil", label="Output Image")
    output_json = gr.outputs.Textbox(label="Predictions (JSON)")
download_test_image()
examples = [["demo_text_ocr.jpg", "results", "DBNet", None, "CRNN", "cpu"],
["demo_text_det.jpg", "results", "FCENet", None, "ASTER", "cpu"],
["demo_text_recog.jpg", "results", "FCENet", None, "MASTER", "cpu"],
]
title = "MMOCR web demo"
description = "<div align='center'><img src='https://raw.githubusercontent.com/open-mmlab/mmocr/main/resources/mmocr-logo.png' width='450''/><div>" \
"<p style='text-align: center'><a href='https://github.com/open-mmlab/mmocr'>MMOCR</a> MMOCR 是基于 PyTorch 和 mmdetection 的开源工具箱,专注于文本检测,文本识别以及相应的下游任务,如关键信息提取。 它是 OpenMMLab 项目的一部分。" \
"OpenMMLab Text Detection, Recognition and Understanding Toolbox.</p>"
article = "<p style='text-align: center'><a href='https://github.com/open-mmlab/mmocr'>MMOCR</a></p>" \
"<p style='text-align: center'><a href='https://github.com/isLinXu'>gradio build by gatilin</a></a></p>"
# Create Gradio interface
iface = gr.Interface(
fn=ocr_inference,
inputs=[
input_image, out_dir, det, det_weights, rec, rec_weights, device
],
outputs=[output_image, output_json], examples=examples,
title=title, description=description, article=article,
)
# Launch Gradio interface
iface.launch()
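    # Launching locally (assuming the dependencies installed above are available):
    #   python app.py
    # Gradio serves the demo at http://127.0.0.1:7860 by default.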