# CTRNetDemo / ctrnet_infer.py
# SWHL's picture
# First commit
# 38a5dfd
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import copy
import time
import cv2
import numpy as np
import pyclipper
from onnxruntime import InferenceSession
from shapely.geometry import Polygon
from rapid_ch_det import TextDetector
class SimpleDataset():
    '''
    Build the arrays fed to the CTRNet ONNX session from one image and the
    text boxes found in it: the normalised image, a copy of it, a binary
    text mask and a soft mask that also covers a band around each box.
    '''

    # Ratio r in ``area * (1 - r ** 2) / perimeter`` used to shrink each box
    # before it is drawn into ``mask``.
    mask_shrink_ratio = 0.95
    # Same formula, used to expand the shrunk box; the resulting offset is
    # also the distance over which the soft border fades out.
    border_ratio = 0.4

    def __call__(self, img: np.ndarray, bboxes: np.ndarray):
        '''
        img: (H, W, 3) image in BGR channel order
        bboxes: (N, 4, 2) integer box corners as (x, y)

        return: (img, structure_im, gt_text, soft_mask)
            img, structure_im: (1, 3, H, W) float32 RGB scaled to [0, 1]
            gt_text: (1, 1, H, W) float32, 1 inside any box, else 0
            soft_mask: (1, 1, H, W) float32, box interior plus soft border
        '''
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Paint every box with a non-zero label, then binarise.
        gt_instance = np.zeros(img.shape[:2], dtype='uint8')
        for idx, box in enumerate(bboxes):
            cv2.drawContours(gt_instance, [box], -1, idx + 1, -1)
        gt_text = (gt_instance > 0).astype(np.float32)[None, None, ...]

        canvas, shrink_mask, mask_ori = self.get_seg_map(img, bboxes)

        # Border band + full box interior, capped at 1 where they overlap.
        soft_mask = canvas + mask_ori
        soft_mask[soft_mask > 1] = 1
        soft_mask = soft_mask[None, None, ...].astype(np.float32)

        img = np.transpose(img, (2, 0, 1)).astype(np.float32) / 255.0
        img = img[None, ...]
        structure_im = copy.deepcopy(img)
        return img, structure_im, gt_text, soft_mask

    @staticmethod
    def _offset_distance(points, ratio):
        '''
        Offset distance ``area * (1 - ratio ** 2) / perimeter`` for a polygon.
        Returns 0.0 for a polygon with zero perimeter instead of dividing
        by zero.
        '''
        shape = Polygon(points)
        if shape.length == 0:
            return 0.0
        return shape.area * (1 - np.power(ratio, 2)) / shape.length

    @staticmethod
    def _offset_polygon(points, delta):
        '''
        Offset a closed polygon by ``delta`` with pyclipper (negative
        shrinks, positive expands).
        Returns the first resulting path as an (M, 2) array, or None when
        the offset produces no path (e.g. the polygon collapsed).
        '''
        offsetter = pyclipper.PyclipperOffset()
        offsetter.AddPath([tuple(pt) for pt in points],
                          pyclipper.JT_ROUND,
                          pyclipper.ET_CLOSEDPOLYGON)
        paths = offsetter.Execute(delta)
        if not paths or len(paths[0]) == 0:
            return None
        return np.array(paths[0])

    def draw_border_map(self, polygon, canvas, mask_ori, mask):
        '''
        Rasterise one text polygon into the three maps, all modified in place.

        polygon: (K, 2) vertices as (x, y)
        canvas: (H, W) float32, receives the soft border (max with existing)
        mask_ori: (H, W) float32, receives the original polygon filled with 1
        mask: (H, W) float32, receives the shrunk polygon filled with 1

        Polygons that collapse while being shrunk or expanded only
        contribute to ``mask_ori`` (and ``mask`` if the shrink succeeded).
        '''
        polygon = np.array(polygon)
        assert polygon.ndim == 2
        assert polygon.shape[1] == 2

        # The full text region is always marked, even if a later step bails.
        cv2.fillPoly(mask_ori, [polygon.astype(np.int32)], 1.0)

        # --- shrink the box and mark the shrunk region ---
        shrink_by = self._offset_distance(polygon, self.mask_shrink_ratio)
        shrunk = self._offset_polygon(polygon, -shrink_by)
        if shrunk is None:
            return
        cv2.fillPoly(mask, [shrunk.astype(np.int32)], 1.0)

        # --- expand the shrunk box to get the extent of the border band ---
        distance = self._offset_distance(shrunk, self.border_ratio)
        if distance <= 0:
            # Nothing to fade over; dividing by it below would only give
            # inf/nan values that leave the canvas unchanged anyway.
            return
        expanded = self._offset_polygon(shrunk, distance)
        if expanded is None:
            return

        xmin = expanded[:, 0].min()
        xmax = expanded[:, 0].max()
        ymin = expanded[:, 1].min()
        ymax = expanded[:, 1].max()
        width = xmax - xmin + 1
        height = ymax - ymin + 1

        # Shrunk polygon expressed in the local frame of the bounding box.
        local_pts = shrunk - np.array([xmin, ymin])

        xs = np.broadcast_to(
            np.linspace(0, width - 1, num=width).reshape(1, width),
            (height, width))
        ys = np.broadcast_to(
            np.linspace(0, height - 1, num=height).reshape(height, 1),
            (height, width))

        # Per-edge normalised distance, then the minimum over all edges.
        n_pts = local_pts.shape[0]
        distance_map = np.zeros((n_pts, height, width), dtype=np.float32)
        for i in range(n_pts):
            j = (i + 1) % n_pts
            edge_distance = self.coumpute_distance(
                xs, ys, local_pts[i], local_pts[j])
            distance_map[i] = np.clip(edge_distance / distance, 0, 1)
        distance_map = distance_map.min(axis=0)

        # Clamp the bounding box to the canvas and paste the matching
        # window of (1 - distance), keeping the larger value per pixel.
        xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
        xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
        ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
        ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
        canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
            1 - distance_map[
                ymin_valid - ymin:ymax_valid - ymax + height,
                xmin_valid - xmin:xmax_valid - xmax + width],
            canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])

    @staticmethod
    def coumpute_distance(xs, ys, point_1, point_2):
        '''
        compute the distance from every grid point to a line segment
        ys: coordinates in the first axis
        xs: coordinates in the second axis
        point_1, point_2: (x, y), the ends of the segment

        Where the two ends subtend an acute angle at the grid point
        (``cosin < 0`` below), the distance to the nearer end is returned
        instead of the perpendicular distance.
        '''
        square_distance_1 = np.square(
            xs - point_1[0]) + np.square(ys - point_1[1])
        square_distance_2 = np.square(
            xs - point_2[0]) + np.square(ys - point_2[1])
        square_distance = np.square(
            point_1[0] - point_2[0]) + np.square(point_1[1] - point_2[1])

        # Negated cosine of the angle at the grid point (law of cosines);
        # the tiny epsilon avoids 0/0 when the point sits on an end.
        cosin = (square_distance - square_distance_1 - square_distance_2) / \
            (2 * np.sqrt(square_distance_1 * square_distance_2) + 1e-50)
        square_sin = np.nan_to_num(1 - np.square(cosin))

        result = np.sqrt(square_distance_1 * square_distance_2 *
                         square_sin / square_distance)
        nearer_end = cosin < 0
        result[nearer_end] = np.sqrt(np.fmin(
            square_distance_1, square_distance_2))[nearer_end]
        return result

    def get_seg_map(self, img, label):
        '''
        Rasterise all polygons in ``label`` onto maps sized like ``img``.

        return: (canvas, mask, mask_ori), each (H, W) float32, as filled in
        by draw_border_map
        '''
        canvas = np.zeros(img.shape[:2], dtype=np.float32)
        mask = np.zeros(img.shape[:2], dtype=np.float32)
        mask_ori = np.zeros(img.shape[:2], dtype=np.float32)
        for polygon in label:
            self.draw_border_map(polygon, canvas, mask_ori, mask=mask)
        return canvas, mask, mask_ori
class CTRNetInfer():
    '''
    Text-removal inference: detect text boxes, build the mask inputs with
    SimpleDataset, run the CTRNet ONNX model on CPU and blend its output
    back into the image inside the masked regions.
    '''

    def __init__(self, model_path) -> None:
        '''
        model_path: path to the CTRNet ONNX file
        '''
        self.session = InferenceSession(model_path,
                                        providers=['CPUExecutionProvider'])
        self.dataset = SimpleDataset()
        self.text_det = TextDetector()
        # Handed to cv2.resize as dsize, i.e. ordered (width, height).
        self.input_shape = (512, 512)

    def __call__(self, ori_img):
        '''
        ori_img: (H, W, 3) BGR image
        return: (H, W, 3) uint8 BGR image with the detected text regions
                replaced by the model prediction
        '''
        ori_img_shape = ori_img.shape[:2]

        det_boxes = self.text_det(ori_img)[0]
        # NOTE(review): assumes the detector yields None or an empty array
        # when no text is found - confirm against rapid_ch_det. With nothing
        # to erase the input is returned unchanged.
        if det_boxes is None or len(det_boxes) == 0:
            return ori_img.copy()
        bboxes = np.asarray(det_boxes).astype(np.int64)

        # Resize the image to the model input size and map the boxes into
        # the resized frame.
        resize_img = cv2.resize(ori_img, self.input_shape,
                                interpolation=cv2.INTER_LINEAR)
        resize_bboxes = self.get_resized_points(bboxes,
                                                ori_img_shape,
                                                self.input_shape)

        img, structure_im, gt_text, soft_mask = self.dataset(
            resize_img, resize_bboxes)
        input_dict = {
            'input': img,
            'gt_text': gt_text,
            'soft_mask': soft_mask,
            'structure_im': structure_im
        }
        # The fourth model output is the one used as the restored image.
        prediction = self.session.run(None, input_dict)[3]

        # Keep the prediction only where the soft mask is set; elsewhere
        # keep the (resized) input pixels.
        blended = prediction * soft_mask + img * (1 - soft_mask)
        blended = np.transpose(blended, (0, 2, 3, 1)) * 255
        # Clip before the uint8 cast so out-of-range predictions saturate
        # instead of wrapping around.
        blended = np.clip(blended.squeeze(), 0, 255).astype(np.uint8)
        # Arrays were built in RGB order; go back to BGR for cv2 consumers.
        blended = cv2.cvtColor(blended, cv2.COLOR_RGB2BGR)

        ori_pred = cv2.resize(blended, ori_img_shape[::-1],
                              interpolation=cv2.INTER_LINEAR)
        return ori_pred

    @staticmethod
    def get_resized_points(cur_points, cur_shape, new_shape):
        '''
        Map (x, y) points from the original image into the resized image.

        cur_points: (N, K, 2) points as (x, y); not modified
        cur_shape: (height, width) of the original image (ndarray.shape[:2])
        new_shape: (width, height) of the resized image (cv2.resize dsize)

        return: int64 array of the same shape with the scaled coordinates,
                truncated toward zero
        '''
        scaled = np.array(cur_points, dtype=np.float64)
        if scaled.size == 0:
            return scaled.astype(np.int64)

        # x must scale with the width and y with the height; cur_shape is
        # (height, width), so the indices are swapped relative to new_shape.
        ratio_x = cur_shape[1] / new_shape[0]
        ratio_y = cur_shape[0] / new_shape[1]
        scaled[..., 0] = scaled[..., 0] / ratio_x
        scaled[..., 1] = scaled[..., 1] / ratio_y
        return scaled.astype(np.int64)
if __name__ == '__main__':
    # Demo: erase the text from one image and write the result to disk.
    model_path = 'CTRNet_G.onnx'
    ctrnet = CTRNetInfer(model_path)

    img_path = 'images/1.jpg'
    ori_img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None rather than raising.
    if ori_img is None:
        raise FileNotFoundError(f'Cannot read image: {img_path}')

    # perf_counter is monotonic, so the elapsed time cannot go negative
    # if the wall clock is adjusted mid-run.
    s = time.perf_counter()
    pred = ctrnet(ori_img)
    print(f'elapse: {time.perf_counter() - s}')

    cv2.imwrite('pred_result.jpg', pred)