File size: 3,467 Bytes
01bb3bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import torch
import cv2
import numpy as np
import time
import warnings


import IndicPhotoOCR.detection.east_config as cfg
from IndicPhotoOCR.detection.east_utils import ModelManager
from IndicPhotoOCR.detection.east_model import East
import IndicPhotoOCR.detection.east_utils as utils

# Suppress warnings
warnings.filterwarnings("ignore")

class EASTdetector:
    def __init__(self, model_name= "east", model_path=None):
        self.model_path = model_path
        # self.model_manager = ModelManager()
        # self.model_manager.ensure_model(model_name)
        # self.ensure_model(self.model_name)
        # self.root_model_dir = "BharatSTR/bharatOCR/detection/East/tmp"

    def detect(self, image_path, model_checkpoint, device):
        # Load image
        im = cv2.imread(image_path)
        # im = cv2.imread(image_path)[:, :, ::-1]

        # Initialize the EAST model and load checkpoint
        model = East()
        model = torch.nn.DataParallel(model, device_ids=cfg.gpu_ids)

        # Load the model checkpoint with weights_only=True
        checkpoint = torch.load(model_checkpoint, map_location=torch.device(device), weights_only=True)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()

        # Resize image and convert to tensor format
        im_resized, (ratio_h, ratio_w) = utils.resize_image(im)
        im_resized = im_resized.astype(np.float32).transpose(2, 0, 1)
        im_tensor = torch.from_numpy(im_resized).unsqueeze(0).cpu()

        # Inference
        timer = {'net': 0, 'restore': 0, 'nms': 0}
        start = time.time()
        score, geometry = model(im_tensor)
        timer['net'] = time.time() - start

        # Process output
        score = score.permute(0, 2, 3, 1).data.cpu().numpy()
        geometry = geometry.permute(0, 2, 3, 1).data.cpu().numpy()
        
        # Detect boxes
        boxes, timer = utils.detect(
            score_map=score, geo_map=geometry, timer=timer,
            score_map_thresh=cfg.score_map_thresh, box_thresh=cfg.box_thresh,
            nms_thres=cfg.box_thresh
        )
        bbox_result_dict = {'detections': []}

        # Parse detected boxes and adjust coordinates
        if boxes is not None:
            boxes = boxes[:, :8].reshape((-1, 4, 2))
            boxes[:, :, 0] /= ratio_w
            boxes[:, :, 1] /= ratio_h
            for box in boxes:
                box = utils.sort_poly(box.astype(np.int32))
                if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                    continue
                bbox_result_dict['detections'].append([
                    [int(coord[0]), int(coord[1])] for coord in box
                ])

        return bbox_result_dict

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Text detection using EAST model')
    parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
    parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
    args = parser.parse_args()

    # Run prediction and get results as dictionary
    east = EASTdetector(model_path = args.model_checkpoint)
    detection_result = east.detect(args.image_path, args.model_checkpoint, args.device)
    # print(detection_result)