Spaces:
Running
on
T4
Running
on
T4
import cv2 | |
import numpy as np | |
import torch | |
from mmdet.registry import VISUALIZERS | |
class SegMaskHelper:
    """Post-processing helpers for instance-segmentation predictions.

    The methods read ``result.pred_instances.masks`` (presumably an mmdet
    ``DetDataSample``; confirm against callers) together with an image array
    whose first two dimensions are (height, width).
    """

    def __init__(self):
        pass

    # Resize the predicted masks to the image size (workaround: the RTMDet
    # config appears to produce masks that do not match the input image size).
    # @timer_func
    def align_masks_with_image(self, result, img):
        """Resize every predicted mask to the spatial size of ``img``.

        Args:
            result: Prediction object exposing ``pred_instances.masks``, an
                iterable of per-instance 2-D torch tensors.
            img: Image array; only ``img.shape[:2]`` (height, width) is used.

        Returns:
            The same ``result`` object, with ``pred_instances.masks`` replaced by
            a CPU ``torch.uint8`` tensor of shape (num_instances, height, width).
            With no instances, the tensor has shape (0, height, width).
        """
        height, width = img.shape[:2]
        resized_masks = [
            # Nearest-neighbour interpolation introduces no new pixel values,
            # so a binary mask stays binary after resizing.
            cv2.resize(
                mask.cpu().numpy().astype(np.uint8),
                (width, height),  # cv2.resize takes (width, height)
                interpolation=cv2.INTER_NEAREST,
            )
            for mask in result.pred_instances.masks
        ]
        if resized_masks:
            stacked_masks = torch.from_numpy(np.stack(resized_masks))
        else:
            # torch.stack([]) raises; a prediction with no instances is still
            # a valid result, so return an empty mask stack instead.
            stacked_masks = torch.zeros((0, height, width), dtype=torch.uint8)
        result.pred_instances.masks = stacked_masks
        return result

    # Crop the image with each mask and put the cropped region on a white background.
    # @timer_func
    def crop_masks(self, result, img, *, epsilon_ratio=0.003):
        """Cut each masked region out of ``img`` and place it on white.

        For every mask, the largest external contour is extracted and
        simplified into a polygon. Its bounding box is used to crop ``img``.
        Pixels of the crop that lie outside the mask are set to 255.

        Args:
            result: Prediction object exposing ``pred_instances.masks``, an
                iterable of 2-D torch tensors with the same height/width as
                ``img``.
            img: Image array (presumably uint8, HxW or HxWxC; confirm).
            epsilon_ratio: Polygon simplification tolerance as a fraction of
                the contour perimeter, passed to ``cv2.approxPolyDP``.

        Returns:
            ``(cropped_imgs, polygons)``: one cropped image and one polygon per
            mask, in mask order. Each polygon is a list of ``[x, y]`` pairs in
            full-image coordinates.

        Raises:
            ValueError: If a mask has no foreground pixels, so there is no
                contour to crop.
        """
        cropped_imgs = []
        polygons = []
        for idx, mask in enumerate(result.pred_instances.masks):
            binary_mask = mask.cpu().numpy().astype(np.uint8)
            contours, _ = cv2.findContours(
                binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
            )
            if not contours:
                raise ValueError(f"Mask {idx} is empty; no contour to crop.")
            # Keep only the largest outer contour so each mask yields one polygon.
            largest_contour = max(contours, key=cv2.contourArea)
            epsilon = epsilon_ratio * cv2.arcLength(largest_contour, True)
            approx_poly = cv2.approxPolyDP(largest_contour, epsilon, True)
            polygons.append(self._as_point_list(approx_poly))

            x, y, w, h = cv2.boundingRect(largest_contour)
            cropped_imgs.append(
                self._on_white_background(
                    img[y : y + h, x : x + w],
                    binary_mask[y : y + h, x : x + w],
                )
            )
        return cropped_imgs, polygons

    @staticmethod
    def _as_point_list(points):
        """Convert an OpenCV point array of shape (N, 1, 2) to ``[[x, y], ...]``."""
        # reshape (rather than squeeze) keeps a single-point polygon as
        # [[x, y]] instead of collapsing it to [x, y].
        return np.asarray(points).reshape(-1, 2).tolist()

    @staticmethod
    def _on_white_background(region, region_mask):
        """Return ``region`` with every pixel where ``region_mask`` is 0 set to 255."""
        white = np.full_like(region, 255)
        # Inside the mask: region & 255 (== region for uint8). Outside: 0.
        inside = cv2.bitwise_and(white, region, mask=region_mask)
        # Zero the white canvas inside the mask so the two parts are disjoint.
        cv2.bitwise_not(white, white, mask=region_mask)
        # Disjoint supports, so the sum cannot overflow uint8.
        return white + inside

    def visualize_result(self, result, img, model_visualizer):
        """Draw ``result``'s predictions on ``img`` and return the rendered image.

        ``model_visualizer`` is the config passed to mmdet's ``VISUALIZERS``
        registry (presumably the model's ``visualizer`` config; confirm).
        Ground-truth annotations are not drawn.
        """
        visualizer = VISUALIZERS.build(model_visualizer)
        visualizer.add_datasample("result", img, data_sample=result, draw_gt=False)
        return visualizer.get_image()

    def _translate_line_coords(self, region_mask, line_polygons):
        """Shift polygons from region-box coordinates to full-image coordinates.

        Args:
            region_mask: 2-D torch tensor; its non-zero pixels define the region.
            line_polygons: Iterable of polygons. Each polygon is an iterable of
                ``(x, y)`` pairs relative to the top-left corner of the bounding
                box of the mask's non-zero pixels.

        Returns:
            A new list of polygons, with the box's x offset added to every x and
            its y offset added to every y.
        """
        # NOTE(review): this box covers ALL non-zero pixels, whereas crop_masks
        # crops by the largest contour only. If a mask has stray blobs, the two
        # offsets differ. Confirm which one callers expect.
        binary_mask = (region_mask.cpu().numpy() > 0).astype(np.uint8)
        x_offset, y_offset, _, _ = cv2.boundingRect(binary_mask)
        return [
            [[px + x_offset, py + y_offset] for px, py in polygon]
            for polygon in line_polygons
        ]