"""Gradio web demo for MMOCR text detection and recognition."""
import os

# NOTE: the dependencies are installed at start-up (hosted-demo pattern), so
# these calls must run before the third-party imports below.
os.system("pip install gradio==3.42.0")
os.system("pip install 'mmengine>=0.6.0'")
os.system("pip install 'mmcv>=2.0.0rc4,<2.1.0'")
os.system("pip install 'mmdet>=3.0.0rc5, < 3.2.0'")
os.system("pip install mmocr")

import json
import tempfile
import warnings
from argparse import ArgumentParser

import PIL
import cv2
import gradio as gr
import numpy as np
import torch
from PIL.Image import Image
from mmocr.apis.inferencers import MMOCRInferencer

warnings.filterwarnings("ignore")


def save_image(img, img_path):
    """Write an RGB image to ``img_path`` using OpenCV.

    Args:
        img: RGB image (PIL image or array-like) convertible by ``np.array``.
        img_path: Destination file path; the extension selects the format.
    """
    # OpenCV expects BGR channel order, so swap the channels before writing.
    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    cv2.imwrite(img_path, bgr)


textdet_model_list = ['DBNet', 'DRRG', 'FCENet', 'PANet', 'PSENet',
                      'TextSnake', 'MaskRCNN']
textrec_model_list = ['ABINet', 'ASTER', 'CRNN', 'MASTER', 'NRTR',
                      'RobustScanner', 'SARNet', 'SATRN', 'SVTR']
textkie_model_list = ['SDMGR']


def ocr_inference(inputs, out_dir, det, det_weights, rec, rec_weights, kie,
                  kie_weights, device):
    """Run MMOCR on one image and return the visualisation and predictions.

    Args:
        inputs: RGB input image (PIL image or array-like).
        out_dir: Directory the inferencer writes results into. An empty value
            falls back to ``'./results/'``.
        det: Text detection model name, or ``None``.
        det_weights: Custom detection checkpoint path. Only used when ``det``
            and ``rec`` are not both given.
        rec: Text recognition model name, or ``None``.
        rec_weights: Custom recognition checkpoint path. Only used when
            ``det`` and ``rec`` are not both given.
        kie: Accepted for interface compatibility; currently not forwarded,
            the value coming from ``parse_args`` is used instead.
        kie_weights: Accepted for interface compatibility; see ``kie``.
        device: Device string forwarded to ``MMOCRInferencer``.

    Returns:
        Tuple ``(image, predictions)``: the saved visualisation opened as a
        PIL image and the saved prediction JSON loaded as Python objects.
    """
    init_args, call_args = parse_args()

    # An empty Gradio textbox is assumed to arrive as "" rather than None;
    # normalise so the "is not None" checks below behave as intended.
    det_weights = det_weights or None
    rec_weights = rec_weights or None
    out_dir = out_dir or './results/'

    if det is not None and rec is not None:
        init_args['det'] = det
        init_args['det_weights'] = None
        init_args['rec'] = rec
        init_args['rec_weights'] = None
    elif det_weights is not None and rec_weights is not None:
        init_args['det'] = None
        init_args['det_weights'] = det_weights
        init_args['rec'] = None
        init_args['rec_weights'] = rec_weights
    init_args['device'] = device

    # Use a unique temporary file per request: a fixed name would overwrite
    # the bundled example image and race between concurrent requests.
    fd, img_path = tempfile.mkstemp(prefix='ocr_input_', suffix='.jpg')
    os.close(fd)
    try:
        save_image(np.array(inputs), img_path)

        call_args['inputs'] = img_path
        call_args['out_dir'] = out_dir
        call_args['batch_size'] = 1
        call_args['show'] = False
        call_args['save_pred'] = True
        call_args['save_vis'] = True

        print("init_args", init_args)
        print("call_args", call_args)

        ocr = MMOCRInferencer(**init_args)
        ocr(**call_args)
    finally:
        if os.path.exists(img_path):
            os.remove(img_path)

    # Results are expected under <out_dir>/vis and <out_dir>/preds, named
    # after the input file's stem (same layout the fixed paths relied on).
    stem = os.path.splitext(os.path.basename(img_path))[0]
    vis_path = os.path.join(out_dir, 'vis', stem + '.jpg')
    pred_path = os.path.join(out_dir, 'preds', stem + '.json')

    img_out = PIL.Image.open(vis_path)
    with open(pred_path) as pred_file:
        json_out = json.load(pred_file)
    return img_out, json_out


def download_test_image():
    """Download the example images used by the demo into the working dir."""
    base_url = 'https://user-images.githubusercontent.com/59380685/'
    # (remote file name, local file name)
    samples = (
        ('266821429-9a897c0a-5b02-4260-a65b-3514b758f6b6.jpg',
         'demo_densetext_det.jpg'),
        ('266821432-17bb0646-a3e9-451e-9b4d-6e41ce4c3f0c.jpg',
         'demo_text_recog.jpg'),
        ('266821434-fe0d4d18-f3e2-4acf-baf5-0d2e318f0b09.jpg',
         'demo_text_ocr.jpg'),
        ('266821435-5d7af2b4-cb84-4355-91cb-37d90e91aa30.jpg',
         'demo_text_det.jpg'),
        ('266821436-4790c6c1-2da5-45c7-b837-04eeea0d7264.jpeg',
         'demo_kie.jpg'),
    )
    for remote_name, local_name in samples:
        torch.hub.download_url_to_file(base_url + remote_name, local_name)


def parse_args():
    """Parse command-line options and split them for the inferencer.

    Returns:
        Tuple ``(init_args, call_args)``: ``init_args`` holds the keyword
        arguments for constructing ``MMOCRInferencer`` (models, weights,
        device); ``call_args`` holds the remaining options for calling it.
    """
    parser = ArgumentParser()
    parser.add_argument(
        '--inputs', type=str, help='Input image file or folder path.')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='./results/',
        help='Output directory of results.')
    parser.add_argument(
        '--det',
        type=str,
        default=None,
        help='Pretrained text detection algorithm. It\'s the path to the '
        'config file or the model name defined in metafile.')
    parser.add_argument(
        '--det-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected det model. '
        'If it is not specified and "det" is a model name of metafile, the '
        'weights will be loaded from metafile.')
    parser.add_argument(
        '--rec',
        type=str,
        default=None,
        help='Pretrained text recognition algorithm. It\'s the path to the '
        'config file or the model name defined in metafile.')
    parser.add_argument(
        '--rec-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected recog model. '
        'If it is not specified and "rec" is a model name of metafile, the '
        'weights will be loaded from metafile.')
    parser.add_argument(
        '--kie',
        type=str,
        default=None,
        help='Pretrained key information extraction algorithm. It\'s the path '
        'to the config file or the model name defined in metafile.')
    parser.add_argument(
        '--kie-weights',
        type=str,
        default=None,
        help='Path to the custom checkpoint file of the selected kie model. '
        'If it is not specified and "kie" is a model name of metafile, the '
        'weights will be loaded from metafile.')
    parser.add_argument(
        '--device',
        type=str,
        default=None,
        help='Device used for inference. '
        'If not specified, the available device will be automatically used.')
    parser.add_argument(
        '--batch-size', type=int, default=1, help='Inference batch size.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Display the image in a popup window.')
    parser.add_argument(
        '--print-result',
        action='store_true',
        help='Whether to print the results.')
    parser.add_argument(
        '--save_pred',
        action='store_true',
        help='Save the inference results to out_dir.')
    parser.add_argument(
        '--save_vis',
        action='store_true',
        help='Save the visualization results to out_dir.')

    call_args = vars(parser.parse_args())

    # Move the constructor-only options out of the call-time options.
    init_kws = [
        'det', 'det_weights', 'rec', 'rec_weights', 'kie', 'kie_weights',
        'device'
    ]
    init_args = {init_kw: call_args.pop(init_kw) for init_kw in init_kws}

    return init_args, call_args


def _gradio_ocr(inputs, out_dir, det, det_weights, rec, rec_weights, device):
    """Adapt the seven UI inputs to ``ocr_inference`` (the UI has no KIE)."""
    return ocr_inference(inputs, out_dir, det, det_weights, rec, rec_weights,
                         None, None, device)


if __name__ == '__main__':
    # Define Gradio input and output types
    input_image = gr.inputs.Image(type="pil", label="Input Image")
    out_dir = gr.inputs.Textbox(default="results")
    det = gr.inputs.Dropdown(label="Text Detection Model",
                             choices=list(textdet_model_list),
                             default='DBNet')
    det_weights = gr.inputs.Textbox(default=None)
    rec = gr.inputs.Dropdown(label="Text Recognition Model",
                             choices=list(textrec_model_list),
                             default='CRNN')
    rec_weights = gr.inputs.Textbox(default=None)
    device = gr.inputs.Radio(choices=["cpu", "cuda"],
                             label="Device used for inference",
                             default="cpu")
    output_image = gr.outputs.Image(type="pil", label="Output Image")
    output_json = gr.outputs.Textbox()

    download_test_image()

    # One value per input component, in the same order as `inputs` below:
    # image, out_dir, det, det_weights, rec, rec_weights, device.
    examples = [
        ["demo_text_ocr.jpg", "results", "DBNet", None, "CRNN", None, "cpu"],
        ["demo_text_det.jpg", "results", "FCENet", None, "ASTER", None,
         "cpu"],
        ["demo_text_recog.jpg", "results", "FCENet", None, "MASTER", None,
         "cpu"],
    ]

    title = "MMOCR web demo"
    description = (
        "\nMMOCR MMOCR 是基于 PyTorch 和 mmdetection 的开源工具箱,专注于文本检测,文本识别以及相应的下游任务,如关键信息提取。 它是 OpenMMLab 项目的一部分。"
        "OpenMMLab Text Detection, Recognition and Understanding Toolbox.\n"
    )
    article = ""

    # Create Gradio interface
    iface = gr.Interface(
        fn=_gradio_ocr,
        inputs=[
            input_image, out_dir, det, det_weights, rec, rec_weights, device
        ],
        outputs=[output_image, output_json],
        examples=examples,
        title=title,
        description=description,
        article=article,
    )

    # Launch Gradio interface
    iface.launch()