# Uploaded by shreyasvaidya via huggingface_hub ("Upload folder using huggingface_hub").
# Commit: 01bb3bb (verified)
import os
import torch
import cv2
import numpy as np
import time
import warnings
import IndicPhotoOCR.detection.east_config as cfg
from IndicPhotoOCR.detection.east_utils import ModelManager
from IndicPhotoOCR.detection.east_model import East
import IndicPhotoOCR.detection.east_utils as utils
# Suppress warnings
warnings.filterwarnings("ignore")
class EASTdetector:
    """Text-region detector backed by an EAST model.

    The checkpoint is loaded on every ``detect`` call; nothing is cached
    between calls.
    """

    def __init__(self, model_name="east", model_path=None):
        """Create a detector.

        Args:
            model_name: Accepted for API compatibility; not used by this class.
            model_path: Default checkpoint path, used by ``detect`` when no
                ``model_checkpoint`` is passed explicitly.
        """
        self.model_path = model_path

    @staticmethod
    def _load_model(model_checkpoint, device):
        """Build an ``East`` model from *model_checkpoint*, on *device*, in eval mode."""
        from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

        checkpoint = torch.load(model_checkpoint, map_location=device, weights_only=True)
        state_dict = checkpoint['state_dict']
        # Checkpoints saved from a torch.nn.DataParallel wrapper carry a
        # "module." key prefix. Strip it (if present) so the bare model can be
        # loaded and then placed on the caller's device. Loading through
        # DataParallel(device_ids=cfg.gpu_ids) would instead let cfg.gpu_ids,
        # not `device`, decide where inference runs.
        consume_prefix_in_state_dict_if_present(state_dict, "module.")
        model = East()
        model.load_state_dict(state_dict)
        return model.to(device).eval()

    def detect(self, image_path, model_checkpoint=None, device="cpu"):
        """Detect text regions in an image.

        Args:
            image_path: Path of the image, read with ``cv2.imread``.
            model_checkpoint: Path to a checkpoint whose ``'state_dict'`` entry
                holds the EAST weights. Defaults to ``self.model_path``.
            device: Torch device string (e.g. ``"cpu"``, ``"cuda"``) on which
                the checkpoint is loaded and inference runs.

        Returns:
            ``{'detections': [...]}``, where each detection is a list of four
            ``[x, y]`` integer points scaled back to the original image size.

        Raises:
            ValueError: If neither ``model_checkpoint`` nor ``self.model_path``
                is set.
            FileNotFoundError: If the image cannot be read.
        """
        if model_checkpoint is None:
            model_checkpoint = self.model_path
        if model_checkpoint is None:
            raise ValueError("No model checkpoint given and no model_path set on the detector")

        # cv2.imread returns None (rather than raising) for missing or
        # undecodable files. NOTE(review): the image is used in OpenCV's
        # default channel order; whether the checkpoint expects BGR or RGB
        # is not visible here -- confirm.
        im = cv2.imread(image_path)
        if im is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")

        dev = torch.device(device)
        model = self._load_model(model_checkpoint, dev)

        # Resize for the network; the ratios map detections back to the
        # original image. HWC -> CHW, then add a batch dimension.
        im_resized, (ratio_h, ratio_w) = utils.resize_image(im)
        im_chw = im_resized.astype(np.float32).transpose(2, 0, 1)
        im_tensor = torch.from_numpy(im_chw).unsqueeze(0).to(dev)

        # Inference (no autograd graph needed).
        timer = {'net': 0, 'restore': 0, 'nms': 0}
        start = time.time()
        with torch.no_grad():
            score, geometry = model(im_tensor)
        timer['net'] = time.time() - start

        # NCHW -> NHWC numpy arrays, as consumed by utils.detect.
        score = score.permute(0, 2, 3, 1).cpu().numpy()
        geometry = geometry.permute(0, 2, 3, 1).cpu().numpy()

        # NOTE(review): nms_thres is fed cfg.box_thresh, which looks like a
        # copy-paste slip (a dedicated NMS threshold was probably intended).
        # Kept as-is to preserve current detections -- confirm against cfg.
        boxes, timer = utils.detect(
            score_map=score, geo_map=geometry, timer=timer,
            score_map_thresh=cfg.score_map_thresh, box_thresh=cfg.box_thresh,
            nms_thres=cfg.box_thresh,
        )

        bbox_result_dict = {'detections': []}
        if boxes is None:
            return bbox_result_dict

        # First 8 columns are the four (x, y) corners; rescale to the
        # original image size.
        boxes = boxes[:, :8].reshape((-1, 4, 2))
        boxes[:, :, 0] /= ratio_w
        boxes[:, :, 1] /= ratio_h

        for box in boxes:
            box = utils.sort_poly(box.astype(np.int32))
            # Drop degenerate boxes whose edge 0-1 or edge 3-0 is shorter than 5 px.
            if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3] - box[0]) < 5:
                continue
            bbox_result_dict['detections'].append([[int(x), int(y)] for x, y in box])
        return bbox_result_dict
if __name__ == "__main__":
    # Command-line entry point: run EAST text detection on one image and
    # print the resulting {'detections': [...]} dictionary.
    import argparse

    parser = argparse.ArgumentParser(description='Text detection using EAST model')
    parser.add_argument('--image_path', type=str, required=True, help='Path to the input image')
    parser.add_argument('--device', type=str, default='cpu', help='Device to run the model on, e.g., "cpu" or "cuda"')
    parser.add_argument('--model_checkpoint', type=str, required=True, help='Path to the model checkpoint file')
    args = parser.parse_args()

    east = EASTdetector(model_path=args.model_checkpoint)
    detection_result = east.detect(args.image_path, args.model_checkpoint, args.device)
    # Print the detections so the script produces visible output.
    print(detection_result)