|
import os |
|
import torch |
|
import cv2 |
|
import numpy as np |
|
import time |
|
import warnings |
|
|
|
|
|
import IndicPhotoOCR.detection.east_config as cfg |
|
from IndicPhotoOCR.detection.east_utils import ModelManager |
|
from IndicPhotoOCR.detection.east_model import East |
|
import IndicPhotoOCR.detection.east_utils as utils |
|
|
|
|
|
warnings.filterwarnings("ignore") |
|
|
|
class EASTdetector:
    """Scene-text detector that runs an EAST model over a single image.

    The constructor only records configuration. The network is built and its
    weights are loaded from the checkpoint on every call to :meth:`detect`.
    """

    def __init__(self, model_name="east", model_path=None):
        """Store the detector configuration.

        Args:
            model_name: Accepted for interface compatibility; this class does
                not read it.
            model_path: Optional default checkpoint path. :meth:`detect` uses
                it when it is called without ``model_checkpoint``.
        """
        self.model_path = model_path

    def detect(self, image_path, model_checkpoint=None, device="cpu"):
        """Detect text regions in the image stored at ``image_path``.

        Args:
            image_path: Path of the image file, read with ``cv2.imread``.
            model_checkpoint: Path of the checkpoint file. It must contain a
                ``'state_dict'`` entry. When ``None``, the ``model_path``
                given to the constructor is used instead.
            device: Device string passed to ``torch.device`` and used as the
                ``map_location`` when loading the checkpoint.

        Returns:
            A dict ``{'detections': [...]}``. Each detection is a list of four
            ``[x, y]`` integer pairs, scaled back to the coordinates of the
            input image.

        Raises:
            ValueError: If no checkpoint path is available, or if the image
                cannot be read.
        """
        # Validate inputs first so a misconfiguration fails fast and clearly.
        checkpoint_path = self.model_path if model_checkpoint is None else model_checkpoint
        if checkpoint_path is None:
            raise ValueError(
                "No model checkpoint given: pass model_checkpoint to detect() "
                "or model_path to EASTdetector()."
            )

        # cv2.imread signals failure by returning None instead of raising.
        im = cv2.imread(image_path)
        if im is None:
            raise ValueError(f"Could not read image from {image_path!r}")

        # The model is wrapped in DataParallel before the weights are loaded;
        # presumably the checkpoint keys carry the wrapper's "module." prefix
        # -- TODO confirm against the training code.
        model = East()
        model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)
        checkpoint = torch.load(
            checkpoint_path, map_location=torch.device(device), weights_only=True
        )
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()

        # Resize, convert to float32, and reorder the axes with
        # transpose(2, 0, 1) before adding a leading batch axis.
        im_resized, (ratio_h, ratio_w) = utils.resize_image(im)
        im_resized = im_resized.astype(np.float32).transpose(2, 0, 1)
        # NOTE(review): the input is always placed on the CPU and the model is
        # never moved to `device`; confirm this is intended for GPU runs.
        im_tensor = torch.from_numpy(im_resized).unsqueeze(0).cpu()

        timer = {'net': 0, 'restore': 0, 'nms': 0}
        start = time.time()
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            score, geometry = model(im_tensor)
        timer['net'] = time.time() - start

        # Move the second axis of each output to the end for utils.detect.
        score = score.permute(0, 2, 3, 1).detach().cpu().numpy()
        geometry = geometry.permute(0, 2, 3, 1).detach().cpu().numpy()

        # NOTE(review): nms_thres is fed cfg.box_thresh, which looks like a
        # copy/paste slip; kept as-is because the config module is not visible
        # here. Confirm whether a dedicated NMS threshold exists.
        boxes, timer = utils.detect(
            score_map=score, geo_map=geometry, timer=timer,
            score_map_thresh=cfg.score_map_thresh, box_thresh=cfg.box_thresh,
            nms_thres=cfg.box_thresh
        )
        bbox_result_dict = {'detections': []}

        if boxes is not None:
            # Keep the first eight values of each row as four (x, y) corners,
            # then undo the resize so coordinates refer to the original image.
            boxes = boxes[:, :8].reshape((-1, 4, 2))
            boxes[:, :, 0] /= ratio_w
            boxes[:, :, 1] /= ratio_h
            for box in boxes:
                box = utils.sort_poly(box.astype(np.int32))
                # Drop boxes whose first or last edge is shorter than 5 units.
                if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                    continue
                bbox_result_dict['detections'].append([
                    [int(coord[0]), int(coord[1])] for coord in box
                ])

        return bbox_result_dict
|
|
|
if __name__ == "__main__":
    # Command-line entry point: run EAST detection on one image and print the
    # resulting detections dict.
    import argparse

    parser = argparse.ArgumentParser(description='Text detection using EAST model')
    parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
    parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
    args = parser.parse_args()

    east = EASTdetector(model_path=args.model_checkpoint)
    detection_result = east.detect(args.image_path, args.model_checkpoint, args.device)
    # Emit the result; without this the script runs the model and shows nothing.
    print(detection_result)
|
|
|
|