import contextlib
import io
import numpy as np
import unittest
from collections import defaultdict
import torch
import tqdm
from fvcore.common.benchmark import benchmark
from pycocotools.coco import COCO
from tabulate import tabulate
from torch.nn import functional as F

from detectron2.data import MetadataCatalog
from detectron2.layers.mask_ops import (
    pad_masks,
    paste_mask_in_image_old,
    paste_masks_in_image,
    scale_boxes,
)
from detectron2.structures import BitMasks, Boxes, BoxMode, PolygonMasks
from detectron2.structures.masks import polygons_to_bitmask
from detectron2.utils.file_io import PathManager
from detectron2.utils.testing import random_boxes


def iou_between_full_image_bit_masks(a, b):
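    # IoU between two full-image binary masks of identical shape.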
    intersect = (a & b).sum()
    union = (a | b).sum()
    return intersect / union


def rasterize_polygons_with_grid_sample(full_image_bit_mask, box, mask_size, threshold=0.5):
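    # Reference rasterization: sample the full-image bit mask at the centers of a
    # mask_size x mask_size grid placed inside `box` via F.grid_sample, then
    # binarize the sampled values with `threshold`.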
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]

    img_h, img_w = full_image_bit_mask.shape

    mask_y = np.arange(0.0, mask_size) + 0.5
    mask_x = np.arange(0.0, mask_size) + 0.5
    mask_y = mask_y / mask_size * (y1 - y0) + y0
    mask_x = mask_x / mask_size * (x1 - x0) + x0

    # Convert image coordinates to the normalized [-1, 1] grid coordinates
    # expected by F.grid_sample with align_corners=True.
    mask_x = (mask_x - 0.5) / (img_w - 1) * 2 - 1
    mask_y = (mask_y - 0.5) / (img_h - 1) * 2 - 1
    gy, gx = torch.meshgrid(torch.from_numpy(mask_y), torch.from_numpy(mask_x))
    ind = torch.stack([gx, gy], dim=-1).to(dtype=torch.float32)

    full_image_bit_mask = torch.from_numpy(full_image_bit_mask)
    mask = F.grid_sample(
        full_image_bit_mask[None, None, :, :].to(dtype=torch.float32),
        ind[None, :, :, :],
        align_corners=True,
    )

    return mask[0, 0] >= threshold


class TestMaskCropPaste(unittest.TestCase):
    def setUp(self):
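        # Load the COCO "coco_2017_val_100" annotations; skip the test case if the
        # json file is not available locally.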
        json_file = MetadataCatalog.get("coco_2017_val_100").json_file
        if not PathManager.isfile(json_file):
            raise unittest.SkipTest("{} not found".format(json_file))
        with contextlib.redirect_stdout(io.StringIO()):
            json_file = PathManager.get_local_path(json_file)
            self.coco = COCO(json_file)

    def test_crop_paste_consistency(self):
        """
        rasterize_polygons_within_box (used in training)
        and
        paste_masks_in_image (used in inference)
        should be inverse operations to each other.

        This function runs several implementations of the above two operations and prints
        the reconstruction error.
        """

        anns = self.coco.loadAnns(self.coco.getAnnIds(iscrowd=False))  # exclude crowd annotations

        selected_anns = anns[:100]

        ious = []
        for ann in tqdm.tqdm(selected_anns):
            results = self.process_annotation(ann)
            ious.append([k[2] for k in results])

        ious = np.array(ious)
        mean_ious = ious.mean(axis=0)
        table = []
        res_dic = defaultdict(dict)
        # `results` from the last annotation still holds the (rasterize, paste) method
        # labels; every annotation produces its rows in the same order.
        for row, iou in zip(results, mean_ious):
            table.append((row[0], row[1], iou))
            res_dic[row[0]][row[1]] = iou
        print(tabulate(table, headers=["rasterize", "paste", "iou"], tablefmt="simple"))

        self.assertTrue(res_dic["polygon"]["aligned"] > 0.94)
        self.assertTrue(res_dic["roialign"]["aligned"] > 0.95)

    def process_annotation(self, ann, mask_side_len=28):
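        # Rasterize one COCO annotation into a mask_side_len x mask_side_len box mask with
        # several methods, paste each result back at full image resolution, and return
        # (rasterize_method, paste_method, iou) tuples measured against the GT bitmask.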
        img_info = self.coco.loadImgs(ids=[ann["image_id"]])[0]
        height, width = img_info["height"], img_info["width"]
        gt_polygons = [np.array(p, dtype=np.float64) for p in ann["segmentation"]]
        gt_bbox = BoxMode.convert(ann["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS)
        gt_bit_mask = polygons_to_bitmask(gt_polygons, height, width)

        # Rasterize the GT mask into the box with three different implementations.
        torch_gt_bbox = torch.tensor(gt_bbox).to(dtype=torch.float32).reshape(-1, 4)
        box_bitmasks = {
            "polygon": PolygonMasks([gt_polygons]).crop_and_resize(torch_gt_bbox, mask_side_len)[0],
            "gridsample": rasterize_polygons_with_grid_sample(gt_bit_mask, gt_bbox, mask_side_len),
            "roialign": BitMasks(torch.from_numpy(gt_bit_mask[None, :, :])).crop_and_resize(
                torch_gt_bbox, mask_side_len
            )[0],
        }

        # Paste each box mask back into a full image, with both the legacy and the
        # aligned implementation.
        results = defaultdict(dict)
        for k, box_bitmask in box_bitmasks.items():
            padded_bitmask, scale = pad_masks(box_bitmask[None, :, :], 1)
            scaled_boxes = scale_boxes(torch_gt_bbox, scale)

            r = results[k]
            r["old"] = paste_mask_in_image_old(
                padded_bitmask[0], scaled_boxes[0], height, width, threshold=0.5
            )
            r["aligned"] = paste_masks_in_image(
                box_bitmask[None, :, :], Boxes(torch_gt_bbox), (height, width)
            )[0]

        table = []
        for rasterize_method, r in results.items():
            for paste_method, mask in r.items():
                mask = np.asarray(mask)
                iou = iou_between_full_image_bit_masks(gt_bit_mask.astype("uint8"), mask)
                table.append((rasterize_method, paste_method, iou))
        return table

    def test_polygon_area(self):
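        # An axis-aligned square with side d has area d ** 2; a right triangle
        # with legs of length d has area d ** 2 / 2.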
        for d in [5.0, 10.0, 1000.0]:
            polygon = PolygonMasks([[[0, 0, 0, d, d, d, d, 0]]])
            area = polygon.area()[0]
            target = d**2
            self.assertEqual(area, target)

        for d in [5.0, 10.0, 1000.0]:
            polygon = PolygonMasks([[[0, 0, 0, d, d, d]]])
            area = polygon.area()[0]
            target = d**2 / 2
            self.assertEqual(area, target)

    def test_paste_mask_scriptable(self):
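        # paste_masks_in_image should be scriptable with torch.jit.script, and the
        # scripted function should match eager execution exactly.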
        scripted_f = torch.jit.script(paste_masks_in_image)
        N = 10
        masks = torch.rand(N, 28, 28)
        boxes = Boxes(random_boxes(N, 100)).tensor
        image_shape = (150, 150)

        out = paste_masks_in_image(masks, boxes, image_shape)
        scripted_out = scripted_f(masks, boxes, image_shape)
        self.assertTrue(torch.equal(out, scripted_out))


def benchmark_paste():
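    # Micro-benchmark paste_masks_in_image on random masks and boxes, on CPU and
    # (when available) CUDA, using fvcore's benchmark helper.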
    S = 800
    H, W = image_shape = (S, S)
    N = 64
    torch.manual_seed(42)
    masks = torch.rand(N, 28, 28)

    # Random boxes roughly 200x200, clipped to the image extent.
    center = torch.rand(N, 2) * 600 + 100
    wh = torch.clamp(torch.randn(N, 2) * 40 + 200, min=50)
    x0y0 = torch.clamp(center - wh * 0.5, min=0.0)
    x1y1 = torch.clamp(center + wh * 0.5, max=S)
    boxes = Boxes(torch.cat([x0y0, x1y1], dim=1))

    def func(device, n=3):
        m = masks.to(device=device)
        b = boxes.to(device=device)

        def bench():
            for _ in range(n):
                paste_masks_in_image(m, b, image_shape)
            if device.type == "cuda":
                torch.cuda.synchronize()

        return bench

    specs = [{"device": torch.device("cpu"), "n": 3}]
    if torch.cuda.is_available():
        specs.append({"device": torch.device("cuda"), "n": 3})

    benchmark(func, "paste_masks", specs, num_iters=10, warmup_iters=2)


if __name__ == "__main__":
    benchmark_paste()
    unittest.main()