File size: 3,857 Bytes
6dfcb0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from kornia.contrib import connected_components
import torch
import pdb
import matplotlib.pyplot as plt
import time

# def reorder_int_labels(x):
#     _, y = torch.unique(x, return_inverse=True)
#     y -= y.min()
#     return y

# def label_connected_component(labels, max_area=500, min_area=20, max_ccs=128, num_iterations=500):

#     assert len(labels.size()) == 2

#     # per-label binary mask
#     unique_labels = torch.unique(labels).reshape(-1, 1, 1)  # [?, 1, 1]
#     binary_masks = (labels.unsqueeze(0) == unique_labels).float()  # [?, H, W]

#     # label connected components
#     cc = connected_components(binary_masks.unsqueeze(1), num_iterations=num_iterations)  # [?, 1, H, W]
#     cc = reorder_int_labels(cc)
#     bincount = torch.bincount(cc.long().flatten())

#     # find all connected components (id, mask, area, valid)
#     # cc_id = torch.nonzero(bincount)  # [num_cc]
#     cc_id = torch.argsort(bincount)[-max_ccs:]
#     cc_mask = (cc == cc_id.reshape(1, -1, 1, 1)).sum(0) # [num_cc, H, W]
#     cc_area = bincount[cc_id]  # [num_cc]
#     valid = (cc_area >= min_area) & (cc_area <= max_area) # [num_cc]
#     valid = valid.reshape(-1, 1, 1)  # [num_cc, 1, 1]

#     # final labels for connected component
#     out = valid * cc_mask
#     out = out.argmax(0)
#     return out

def reorder_int_labels(x):
    """Map the values of ``x`` onto consecutive integers starting at 0.

    The smallest distinct value in ``x`` becomes 0, the next one 1, and so
    on, so the result only contains ids in ``[0, number_of_distinct_values)``.

    Args:
        x: tensor of labels (any shape).

    Returns:
        An int64 tensor with the same shape as ``x`` holding the compacted ids.
    """
    # An empty tensor has nothing to relabel; return an empty id tensor of
    # the same shape instead of failing on a reduction over zero elements.
    if x.numel() == 0:
        return torch.empty_like(x, dtype=torch.long)

    # The inverse indices returned by torch.unique index into the sorted list
    # of distinct values, so they already start at 0 and are contiguous; no
    # further shifting is required.
    _, compact_ids = torch.unique(x, return_inverse=True)
    return compact_ids


def label_connected_component(labels, min_area=20, topk=256):
    """Relabel a 2-D label map by spatially connected components.

    Every distinct value in ``labels`` is converted to a binary mask, each
    mask is split into connected components, and each pixel is then assigned
    the index of the selected component that covers it.

    All intermediate tensors are kept on ``labels.device``, so the function
    works for CPU inputs and for inputs on any CUDA device.

    Args:
        labels: 2-D tensor ``[H, W]`` of label ids.
        min_area: components with fewer pixels than this are treated as
            invalid and do not appear in the output.
        topk: when at least ``topk`` component ids exist, only the ``topk``
            largest (by pixel count) are considered.

    Returns:
        An int64 tensor ``[H, W]``. When fewer than ``topk`` component ids
        exist, each pixel holds the (reordered) id of its valid component;
        otherwise it holds that component's position in the descending-area
        ranking returned by ``torch.topk``. Pixels not covered by any valid
        selected component get 0 (argmax over an all-zero column), so index 0
        is ambiguous with "no component".
    """
    size = labels.size()
    assert len(size) == 2, "labels must be a 2-D tensor [H, W]"
    device = labels.device

    # Components with H * W pixels or more are rejected. NOTE(review): this
    # presumably filters out the merged background id (0) produced by
    # connected_components -- confirm against kornia's output convention.
    max_area = size[0] * size[1] - 1

    # per-label binary mask
    unique_labels = torch.unique(labels).reshape(-1, 1, 1)  # [?, 1, 1], where ? is the number of unique ids
    binary_masks = (labels.unsqueeze(0) == unique_labels).float()  # [?, H, W]

    # label connected components
    # cc is a tensor in which each unique id represents a single connected component
    cc = connected_components(binary_masks.unsqueeze(1), num_iterations=500)  # [?, 1, H, W]

    # compact the ids in cc so that the cc_area tensor below is as small as possible
    cc = reorder_int_labels(cc)

    # area (pixel count) of each connected component; the count is computed on
    # CPU (original author's note: bincount on GPU is much slower) and then
    # moved back to the input's device rather than to a hard-coded CUDA device
    cc_area = torch.bincount(cc.long().flatten().cpu()).to(device)
    num_cc = cc_area.shape[0]
    valid = (cc_area >= min_area) & (cc_area <= max_area)  # [num_cc]

    if num_cc < topk:
        # few enough components: keep all of them, in id order
        selected_cc = torch.arange(num_cc, device=device)
    else:
        # too many components: keep only the topk largest, ordered by area
        _, selected_cc = torch.topk(cc_area, k=topk)
        valid = valid[selected_cc]

    # Compare every pixel against every selected id ([?, K, H, W]) and collapse
    # the 0th dimension; a component id is expected to occur in only one of
    # the per-label masks, so the sum is a 0/1 mask per selected component.
    cc_mask = (cc == selected_cc.reshape(1, -1, 1, 1)).sum(0)  # [K, H, W]
    cc_mask = cc_mask * valid.reshape(-1, 1, 1)  # zero out invalid components
    out = cc_mask.argmax(0)  # [H, W]
    return out


# def reorder_int_labels(x):
#     _, y = torch.unique(x, return_inverse=True)
#     y -= y.min()
#     return y

# def label_connected_component(labels, max_area=500, min_area=20):

#     assert len(labels.size()) == 2

#     # per-label binary mask
#     unique_labels = torch.unique(labels).reshape(-1, 1, 1)  # [?, 1, 1]
#     binary_masks = (labels.unsqueeze(0) == unique_labels).float()  # [?, H, W]

#     # label connected components
#     cc = connected_components(binary_masks.unsqueeze(1))  # [?, 1, H, W]
#     cc = reorder_int_labels(cc)
#     bincount = torch.bincount(cc.long().flatten())

#     # find all connected components (id, mask, area, valid)
#     cc_id = torch.nonzero(bincount)  # [num_cc]
#     cc_mask = (cc == cc_id.reshape(1, -1, 1, 1)).sum(0) # [num_cc, H, W]
#     cc_area = bincount[cc_id]  # [num_cc]
#     valid = (cc_area >= min_area) & (cc_area <= max_area) # [num_cc]
#     valid = valid.reshape(-1, 1, 1)  # [num_cc, 1, 1]

#     # final labels for connected component
#     out = valid * cc_mask
#     out = out.argmax(0)