rahulvenkk
app.py updated
6dfcb0f
raw
history blame
3.86 kB
from kornia.contrib import connected_components
import torch
import pdb
import matplotlib.pyplot as plt
import time
# def reorder_int_labels(x):
# _, y = torch.unique(x, return_inverse=True)
# y -= y.min()
# return y
# def label_connected_component(labels, max_area=500, min_area=20, max_ccs=128, num_iterations=500):
# assert len(labels.size()) == 2
# # per-label binary mask
# unique_labels = torch.unique(labels).reshape(-1, 1, 1) # [?, 1, 1]
# binary_masks = (labels.unsqueeze(0) == unique_labels).float() # [?, H, W]
# # label connected components
# cc = connected_components(binary_masks.unsqueeze(1), num_iterations=num_iterations) # [?, 1, H, W]
# cc = reorder_int_labels(cc)
# bincount = torch.bincount(cc.long().flatten())
# # find all connected components (id, mask, area, valid)
# # cc_id = torch.nonzero(bincount) # [num_cc]
# cc_id = torch.argsort(bincount)[-max_ccs:]
# cc_mask = (cc == cc_id.reshape(1, -1, 1, 1)).sum(0) # [num_cc, H, W]
# cc_area = bincount[cc_id] # [num_cc]
# valid = (cc_area >= min_area) & (cc_area <= max_area) # [num_cc]
# valid = valid.reshape(-1, 1, 1) # [num_cc, 1, 1]
# # final labels for connected component
# out = valid * cc_mask
# out = out.argmax(0)
# return out
def reorder_int_labels(x):
    """Map the distinct values of ``x`` onto consecutive integers from 0.

    The smallest distinct value becomes 0, the next one 1, and so on, so the
    relative order of the ids is preserved. The returned tensor has the same
    shape as ``x`` and holds the inverse indices produced by ``torch.unique``.
    """
    # torch.unique's inverse indices already lie in 0..num_unique-1, so no
    # offset is needed (and subtracting y.min() would fail on empty input).
    _, y = torch.unique(x, return_inverse=True)
    return y
def label_connected_component(labels, min_area=20, topk=256, max_area=None, num_iterations=500):
    """Split a 2-D label map into connected components and relabel them.

    Every distinct value in ``labels`` becomes its own binary mask. The
    connected components of each mask are found with kornia's
    ``connected_components``. Components whose pixel area falls outside
    ``[min_area, max_area]`` are discarded.

    Args:
        labels: 2-D tensor ``[H, W]`` of segment ids.
        min_area: smallest component area (in pixels) that is kept.
        topk: maximum number of components considered. When there are at
            least ``topk`` components, only the ``topk`` largest by area are
            kept.
        max_area: largest component area that is kept. Defaults to
            ``H * W - 1``, which rejects a component covering the whole image.
        num_iterations: iteration count forwarded to ``connected_components``.

    Returns:
        Integer tensor ``[H, W]`` on the same device as ``labels``. Each pixel
        holds the position (within the selected-component list) of the valid
        component it belongs to. Pixels whose component was rejected have an
        all-zero score column, so ``argmax`` assigns them index 0, the same
        index as the first selected component.
        NOTE(review): index 0 is presumably the large background component
        (rejected by the area test); confirm before relying on 0 == "none".
    """
    size = labels.size()
    assert len(size) == 2
    if max_area is None:
        max_area = size[0] * size[1] - 1
    # Keep every intermediate tensor on the caller's device; hard-coding
    # .cuda() broke CPU inputs and inputs living on a non-default GPU.
    device = labels.device

    # per-label binary mask
    unique_labels = torch.unique(labels).reshape(-1, 1, 1)  # [?, 1, 1], where ? is the number of unique ids
    binary_masks = (labels.unsqueeze(0) == unique_labels).float()  # [?, H, W]

    # label connected components: each unique value of cc is one component
    cc = connected_components(binary_masks.unsqueeze(1), num_iterations=num_iterations)  # [?, 1, H, W]
    # compact the ids to 0..num_cc-1 so the area tensor below stays small
    cc = reorder_int_labels(cc)

    # Area of each connected component. bincount on GPU is much slower, so
    # count on the CPU and move the result back to the input's device.
    cc_area = torch.bincount(cc.long().flatten().cpu()).to(device)
    num_cc = cc_area.shape[0]
    valid = (cc_area >= min_area) & (cc_area <= max_area)  # [num_cc]

    if num_cc < topk:
        selected_cc = torch.arange(num_cc, device=device)
    else:
        _, selected_cc = torch.topk(cc_area, k=topk)
    valid = valid[selected_cc]

    # Each pixel matches exactly one component across the 0th dimension, so
    # summing over it collapses the per-mask axis into one membership mask.
    cc_mask = (cc == selected_cc.reshape(1, -1, 1, 1)).sum(0)  # [K, H, W]
    # zero out components that failed the area test before picking labels
    cc_mask = cc_mask * valid.reshape(-1, 1, 1)
    return cc_mask.argmax(0)
# def reorder_int_labels(x):
# _, y = torch.unique(x, return_inverse=True)
# y -= y.min()
# return y
# def label_connected_component(labels, max_area=500, min_area=20):
# assert len(labels.size()) == 2
# # per-label binary mask
# unique_labels = torch.unique(labels).reshape(-1, 1, 1) # [?, 1, 1]
# binary_masks = (labels.unsqueeze(0) == unique_labels).float() # [?, H, W]
# # label connected components
# cc = connected_components(binary_masks.unsqueeze(1)) # [?, 1, H, W]
# cc = reorder_int_labels(cc)
# bincount = torch.bincount(cc.long().flatten())
# # find all connected components (id, mask, area, valid)
# cc_id = torch.nonzero(bincount) # [num_cc]
# cc_mask = (cc == cc_id.reshape(1, -1, 1, 1)).sum(0) # [num_cc, H, W]
# cc_area = bincount[cc_id] # [num_cc]
# valid = (cc_area >= min_area) & (cc_area <= max_area) # [num_cc]
# valid = valid.reshape(-1, 1, 1) # [num_cc, 1, 1]
# # final labels for connected component
# out = valid * cc_mask
# out = out.argmax(0)