|
import torch |
|
import onnx |
|
import onnxruntime as rt |
|
from torchvision import transforms as T |
|
from PIL import Image |
|
from tokenizer_base import Tokenizer |
|
import pathlib |
|
import os |
|
import gradio as gr |
|
from huggingface_hub import Repository |
|
|
|
# Resolve everything relative to this script, not the process working
# directory, so the clone location and `model_file` below always agree.
cwd = pathlib.Path(__file__).parent.resolve()

# Clone (or reuse) the model repository next to this script and update it.
# `token=True` presumably makes huggingface_hub use the locally saved login
# token (the repo looks private) — confirm against the deployed environment.
# NOTE(review): `huggingface_hub.Repository` is deprecated upstream in favour
# of HTTP-based helpers such as `snapshot_download`; consider migrating.
repo = Repository(
    local_dir=os.path.join(cwd, "secret_models"),
    repo_type="model",
    clone_from="docparser/captcha",
    token=True,
)
repo.git_pull()

model_file = os.path.join(cwd, "secret_models", "captcha.onnx")

# Resize target passed to torchvision's T.Resize, i.e. (height, width).
img_size = (32, 128)

# NOTE(review): this is a raw string, so `\"` and `\\` keep their backslashes
# and the charset contains literal backslash characters. The tokenizer's
# index mapping presumably has to match the model's training vocabulary
# exactly — do not "clean up" this literal without checking the model.
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"

tokenizer_base = Tokenizer(charset)
|
|
|
def get_transform(img_size):
    """Build the preprocessing pipeline applied to every input image.

    Args:
        img_size: Target size handed to ``T.Resize``.

    Returns:
        A ``T.Compose`` that resizes with bicubic interpolation, converts the
        image to a tensor, and normalizes it with mean 0.5 and std 0.5.
    """
    resize = T.Resize(img_size, T.InterpolationMode.BICUBIC)
    # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1].
    return T.Compose([resize, T.ToTensor(), T.Normalize(0.5, 0.5)])
|
|
|
def to_numpy(tensor):
    """Return a NumPy array holding the contents of *tensor*.

    Tensors that track gradients are detached first, because ``.numpy()``
    refuses to operate on them. The data is always moved to the CPU.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
|
|
|
def initialize_model(model_file):
    """Validate the ONNX model and create the preprocessing/inference pair.

    Args:
        model_file: Path to the ``.onnx`` model file.

    Returns:
        A ``(transform, ort_session)`` tuple: the image preprocessing pipeline
        built from the module-level ``img_size``, and an ONNX Runtime
        inference session for ``model_file``.

    Errors raised by ``onnx.load`` / ``onnx.checker.check_model`` (e.g. for a
    missing or invalid model) propagate to the caller.
    """
    transform = get_transform(img_size)

    onnx_model = onnx.load(model_file)
    onnx.checker.check_model(onnx_model)

    # Since ONNX Runtime 1.9, builds that ship more than one execution
    # provider (e.g. GPU builds) raise if `providers` is omitted. Passing the
    # available providers keeps the previous "use everything available"
    # default and behaves the same on CPU-only builds.
    ort_session = rt.InferenceSession(
        model_file, providers=rt.get_available_providers()
    )
    return transform, ort_session
|
|
|
def get_text(img_org):
    """Recognise the text in a captcha image.

    Uses the module-level ``transform``, ``ort_session`` and
    ``tokenizer_base`` that are set up when the app starts.

    Args:
        img_org: A PIL image, or ``None`` when the Gradio input is empty.

    Returns:
        The decoded string for the image, or ``""`` when no image is given.
    """
    # Gradio passes None when the user submits without an image; without this
    # guard the `.convert` call below raises AttributeError.
    if img_org is None:
        return ""

    # Preprocess and add a leading batch dimension of size 1.
    x = transform(img_org.convert('RGB')).unsqueeze(0)

    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    logits = ort_session.run(None, ort_inputs)[0]

    probs = torch.tensor(logits).softmax(-1)
    # The second value returned by decode (probabilities, per the original
    # naming) is not needed here.
    preds, _ = tokenizer_base.decode(probs)
    pred = preds[0]  # only one image in the batch
    print(pred)
    return pred
|
|
|
# Build the preprocessing pipeline and ONNX session used by `get_text`.
transform, ort_session = initialize_model(model_file=model_file)

# The legacy `gr.outputs.*` namespace was removed in Gradio 4; `gr.Textbox`
# is the current component and matches the top-level `gr.Image` input below.
gr.Interface(
    get_text,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="Text Captcha Reader",
    examples=[
        "8000.png", "11JW29.png", "2a8486.jpg", "2nbcx.png",
        "000679.png", "000HU.png", "00Uga.png.jpg", "00bAQwhAZU.jpg",
        "00h57kYf.jpg", "0EoHdtVb.png", "0JS21.png", "0p98z.png", "10010.png",
    ],
).launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|