import pdb
import time

import matplotlib.pyplot as plt
import torch
from kornia.contrib import connected_components


# --- Disabled variant (commented out in the original file) ---
# def reorder_int_labels(x):
#     _, y = torch.unique(x, return_inverse=True)
#     y -= y.min()
#     return y

# def label_connected_component(labels, max_area=500, min_area=20, max_ccs=128, num_iterations=500):
#     assert len(labels.size()) == 2
#     # per-label binary mask
#     unique_labels = torch.unique(labels).reshape(-1, 1, 1)  # [?, 1, 1]
#     binary_masks = (labels.unsqueeze(0) == unique_labels).float()  # [?, H, W]
#     # label connected components
#     cc = connected_components(binary_masks.unsqueeze(1), num_iterations=num_iterations)  # [?, 1, H, W]
#     cc = reorder_int_labels(cc)
#     bincount = torch.bincount(cc.long().flatten())
#     # find all connected components (id, mask, area, valid)
#     # cc_id = torch.nonzero(bincount)  # [num_cc]
#     cc_id = torch.argsort(bincount)[-max_ccs:]
#     cc_mask = (cc == cc_id.reshape(1, -1, 1, 1)).sum(0)  # [num_cc, H, W]
#     cc_area = bincount[cc_id]  # [num_cc]
#     valid = (cc_area >= min_area) & (cc_area <= max_area)  # [num_cc]
#     valid = valid.reshape(-1, 1, 1)  # [num_cc, 1, 1]
#     # final labels for connected component
#     out = valid * cc_mask
#     out = out.argmax(0)
#     return out


def reorder_int_labels(x):
    """Relabel the values of ``x`` with consecutive integers.

    Every element is replaced by the index of its value in the sorted list
    of unique values of ``x`` (``torch.unique(..., return_inverse=True)``),
    then shifted so that the smallest index is 0.

    Args:
        x: tensor of labels.

    Returns:
        An integer tensor with the same shape as ``x`` holding the compacted
        labels.
    """
    _, y = torch.unique(x, return_inverse=True)
    y -= y.min()
    return y


def label_connected_component(labels, min_area=20, topk=256):
    """Split a 2-D label map into connected components and relabel them.

    For every unique value in ``labels`` a binary mask is built, and
    ``kornia.contrib.connected_components`` is run on the stack of masks.
    Components are then filtered by area and at most ``topk`` of them
    (the largest ones) are kept.

    All intermediate tensors are created on ``labels.device``, so the
    function works for both CPU and CUDA inputs.

    Args:
        labels: 2-D tensor ``[H, W]`` of integer labels.
        min_area: minimum number of pixels a component must have to be kept.
        topk: maximum number of components considered; when there are more,
            only the ``topk`` largest by area are selected.

    Returns:
        A ``[H, W]`` integer tensor. Each pixel holds the index (within the
        selected components) of the valid component covering it; pixels not
        covered by any valid selected component get 0, because ``argmax``
        over an all-zero column returns the first index.
    """
    size = labels.size()
    assert len(size) == 2, "labels must be a 2-D [H, W] tensor"
    device = labels.device

    # A component spanning the whole image (area == H * W) is rejected.
    max_area = size[0] * size[1] - 1

    # per-label binary mask
    # [?, 1, 1], where ? is the number of unique ids
    unique_labels = torch.unique(labels).reshape(-1, 1, 1)
    binary_masks = (labels.unsqueeze(0) == unique_labels).float()  # [?, H, W]

    # label connected components
    # cc is an integer tensor, each unique id represents a single connected
    # component
    cc = connected_components(
        binary_masks.unsqueeze(1), num_iterations=500
    )  # [?, 1, H, W]
    # reorder indices in cc so that the cc_area tensor below is smaller
    cc = reorder_int_labels(cc)

    # area of each connected component; computed on the CPU because, per the
    # original author's note, bincount on GPU is much slower
    cc_area = torch.bincount(cc.long().flatten().cpu()).to(device)
    num_cc = cc_area.shape[0]
    valid = (cc_area >= min_area) & (cc_area <= max_area)  # [num_cc]

    if num_cc < topk:
        selected_cc = torch.arange(num_cc, device=device)
    else:
        _, selected_cc = torch.topk(cc_area, k=topk)
    valid = valid[selected_cc]

    # collapse the 0th dimension: each component id is matched in only one
    # of the per-label masks (apart from the shared background id)
    cc_mask = (cc == selected_cc.reshape(1, -1, 1, 1)).sum(0)  # [num_selected, H, W]
    # zero out components that failed the area test
    cc_mask = cc_mask * valid.reshape(-1, 1, 1)
    out = cc_mask.argmax(0)
    return out


# --- Disabled variant (commented out in the original file) ---
# def reorder_int_labels(x):
#     _, y = torch.unique(x, return_inverse=True)
#     y -= y.min()
#     return y

# def label_connected_component(labels, max_area=500, min_area=20):
#     assert len(labels.size()) == 2
#     # per-label binary mask
#     unique_labels = torch.unique(labels).reshape(-1, 1, 1)  # [?, 1, 1]
#     binary_masks = (labels.unsqueeze(0) == unique_labels).float()  # [?, H, W]
#     # label connected components
#     cc = connected_components(binary_masks.unsqueeze(1))  # [?, 1, H, W]
#     cc = reorder_int_labels(cc)
#     bincount = torch.bincount(cc.long().flatten())
#     # find all connected components (id, mask, area, valid)
#     cc_id = torch.nonzero(bincount)  # [num_cc]
#     cc_mask = (cc == cc_id.reshape(1, -1, 1, 1)).sum(0)  # [num_cc, H, W]
#     cc_area = bincount[cc_id]  # [num_cc]
#     valid = (cc_area >= min_area) & (cc_area <= max_area)  # [num_cc]
#     valid = valid.reshape(-1, 1, 1)  # [num_cc, 1, 1]
#     # final labels for connected component
#     out = valid * cc_mask
#     out = out.argmax(0)