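"""Generate zero-shot segmentation masks from a pretrained CWM model.

A frozen DINO ViT provides per-image "predominance" maps via a MaskCut-style
spectral bipartition of patch features; a CWM segment teacher then extracts
segment masks using those maps as its sampling distribution, and the results
are saved with torch.save (see the argparse options under __main__).
"""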
import torch
import os
import glob
import time
from torchvision.io import read_image
import matplotlib.pyplot as plt
from scipy import ndimage
from PIL import Image

#import bbnet.trainval.validator as validator
import modeling_pretrain_cleaned as vmae_transformers
import modeling_pretrain as vmae_transformers_old
#import positional_vmae as pos_transformers
#import big_models as big_transformers
import bbnet.models.preprocessor as preprocessor
import bbnet.models.error as error_generator
from functools import partial
import bbnet.models.teachers as teachers  # used below for the segment teacher and checkpoint loading
from tqdm import tqdm
from torch.nn import functional as F
import argparse
import sys
import numpy as np
import json
import pycocotools.mask as mask_util
sys.path.append('/ccn2/u/honglinc/CutLER')
sys.path.append('/ccn2/u/honglinc/CutLER/maskcut')
sys.path.append('/ccn2/u/honglinc/CutLER/third_party')
import dino
import maskcut
# from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners, get_masked_affinity_matrix
from third_party.TokenCut.unsupervised_saliency_detection import utils, metric
from third_party.TokenCut.unsupervised_saliency_detection.object_discovery import detect_box
from crf import densecrf
#from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners
# DINO hyperparameters
vit_arch = 'base'
vit_feat = 'k'
patch_size = 8
# DINO pre-trained model
url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
feat_dim = 768
dino_backbone = dino.ViTFeat(url, feat_dim, vit_arch, vit_feat, patch_size)
dino_backbone = dino_backbone.eval().requires_grad_(False).cuda()
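
# The functions below are a batched re-implementation of the MaskCut / Normalized-Cut
# steps: build a patch-affinity graph from DINO features, take the second-smallest
# generalized eigenvector of its Laplacian, and bipartition patches into foreground
# and background.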

def get_affinity_matrix(feats, tau, eps=1e-5):
    # get affinity matrix via measuring patch-wise cosine similarity
    feats = F.normalize(feats, p=2, dim=1)
    A = (feats.permute(0, 2, 1) @ feats)
    # convert the affinity matrix to a binary one.
    A = A > tau
    eps = torch.ones_like(A) * eps
    A = torch.where(A.float() == 0, eps, A)
    d_i = A.sum(-1)
    D = torch.diag_embed(d_i)
    return A, D

def second_smallest_eigenvector(A, D):
    # get the second smallest eigenvector from affinity matrix
    _, eigenvectors = torch.lobpcg(D - A, B=D, k=2, largest=False)
    second_smallest_vec = eigenvectors[:, :, 1]
    return -second_smallest_vec

def get_salient_areas(second_smallest_vec):
    # get the area corresponding to salient objects.
    avg = second_smallest_vec.mean(-1, keepdims=True)
    bipartition = second_smallest_vec > avg
    return bipartition

def check_num_fg_corners(bipartition, dims):
    # check number of corners belonging to the foreground
    dims = [bipartition.shape[0]] + dims
    bipartition_ = bipartition.reshape(dims)
    top_l, top_r, bottom_l, bottom_r = bipartition_[:,0,0], bipartition_[:,0,-1], bipartition_[:,-1,0], bipartition_[:,-1,-1]
    nc = top_l.int() + top_r.int() + bottom_l.int() + bottom_r.int()
    return nc

def get_dino_predominance(images, dims=[28, 28], current_mask=None, painting=None, img_size=[224, 224]):
    input_dino = images
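    # standard ImageNet mean/std normalization (the statistics DINO was trained with)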
    input_dino = input_dino - torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input_dino.device)
    input_dino = input_dino / torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input_dino.device)
    # input_dino = images.tensor
    input_dino = torch.nn.functional.interpolate(input_dino, size=img_size, mode='bilinear')
    feats = dino_backbone(input_dino) # [B, C, N]
    B = feats.shape[0]

    predominance_map = []
    if current_mask is None:
        painting = torch.from_numpy(np.zeros(dims))
        painting = painting.to(feats)
    else:
        # mask out previously extracted segments before recomputing affinities
        feats, painting = maskcut.get_masked_affinity_matrix(painting, feats, current_mask, ps=dims[0])

    A, D = get_affinity_matrix(feats, tau=0.15)
    # get the second-smallest generalized eigenvector of (D - A) v = lambda D v
    second_smallest_vec = second_smallest_eigenvector(A, D)

    # get salient area
    bipartition = get_salient_areas(second_smallest_vec)

    # check if we should reverse the partition based on:
    # 1) peak of the 2nd smallest eigvec 2) object centric bias
    batch_inds = torch.arange(second_smallest_vec.shape[0]).to(second_smallest_vec).unsqueeze(0)
    seed = torch.argmax(second_smallest_vec.abs(), dim=-1).unsqueeze(0)
    seed = torch.cat([batch_inds, seed], dim=0).long()

    reverse = bipartition[list(seed)] != 1

    nc = check_num_fg_corners(bipartition, dims)
    reverse[nc >= 2] = True
    second_smallest_vec[reverse] = 1 - second_smallest_vec[reverse]

    second_smallest_vec = second_smallest_vec.to(images.device).contiguous()
    pred_map = torch.nn.functional.interpolate(second_smallest_vec.reshape(B, 1, dims[0], dims[1]), size=img_size,
                                               mode='bicubic')
    # normalize to [0, 1] so the map can serve as a sampling distribution
    pred_map -= pred_map.min()
    pred_map /= pred_map.max()
    predominance_map.append(pred_map)
    init_dist = torch.cat(predominance_map, dim=0).detach().contiguous()

    return init_dist, A, feats, painting
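
# Example (sketch; shapes follow the 480x480 / patch-8 setting used below):
#   imgs = torch.rand(2, 3, 480, 480).cuda()   # RGB in [0, 1]
#   pred, A, feats, painting = get_dino_predominance(imgs, dims=[60, 60], img_size=[480, 480])
#   pred has shape [2, 1, 480, 480], normalized to [0, 1]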




def interpolate_pos_encoding(pos_embed, n_frames, h, w):
    N = pos_embed.shape[1]
    if N == (h * w * n_frames):
        return pos_embed
    old_h = old_w = int((N / n_frames) ** 0.5)
    patch_pos_embed = pos_embed.view(1, n_frames, old_h, old_w, -1).flatten(0, 1).permute(0, 3, 1, 2)

    patch_pos_embed = F.interpolate(
        patch_pos_embed,
        size=(h, w),
        mode='bicubic',
    )
    return patch_pos_embed.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
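
# interpolate_pos_encoding resizes a [1, n_frames * old_h * old_w, C] positional
# embedding to [1, n_frames * h * w, C] by bicubic interpolation of each frame's
# spatial grid; it is used below to adapt the pretrained predictor to 480x480 inputs.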



def vis_results(x, targets_list, predominance, annotation, name):
    B = x.shape[0]

    fig, axs = plt.subplots(B, 2 + len(targets_list), figsize=(3 * (2 + len(targets_list)), 2 * B))

    for b in range(B):
        img = x[b, 0].permute(1, 2, 0).cpu()
        axs[b, 0].imshow(img)
        axs[b, 0].set_title('Image')
        axs[b, 1].imshow(predominance[b, 0].cpu())
        axs[b, 1].set_title('Predominance')

        for i, v in enumerate(targets_list):
            v = v[b, 0]  # .cpu()
            # show the image inside the segment and white elsewhere
            axs[b, 2 + i].imshow((v[..., None] * img) + (~v[..., None] * torch.ones_like(img)))
            axs[b, 2 + i].set_title(f'Segment {i}', fontsize=10)

    for ax in axs:
        for a in ax:
            a.set_axis_off()

    plt.show()
    plt.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Generate zero-shot segments from CWM model', add_help=False)
    parser.add_argument('--input_pattern', default='/ccn2/u/honglinc/datasets/coco/images/val2017/*', nargs='+', type=str, help='Pattern for input images')
    parser.add_argument('--output', default='./output.pt', type=str, help='output path for saving the results')
    parser.add_argument('--num_iter', default=1, type=int, help='number of iterations')
    parser.add_argument('--visualize', action='store_true', help='Visualize the results')
    args = parser.parse_args()
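
    # Example invocation (script name and paths are placeholders):
    #   python extract_segments.py --input_pattern '/path/to/val2017/*' --output ./segments.pt --num_iter 1 --visualize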

    ## Prepare for the extraction
    image_list = glob.glob(args.input_pattern) if isinstance(args.input_pattern, str) else args.input_pattern
    thresh = 0.5
    visualize = args.visualize
    save_dict = {}
    image_size = [480, 480]
    patch_size = 8
    dims = [int(s / patch_size) for s in image_size]
    batch_size = 10

    ## Load pretrained model
    default_model_dir = '/ccn2/u/honglinc/cwm_checkpoints/'
    model_func = vmae_transformers.vitb_8x8patch_3frames
    ckpt_path = 'ablation_3frame_no_clumping_mr0.90_extra_data_ep400'
    label = '3 frame 8x8'
    teacher_func = teachers.iteration_segment_teacher_with_filter

    teacher = teacher_func(
        model_func=model_func,
        model_path=teachers.get_load_path(os.path.join(default_model_dir, ckpt_path), model_checkpoint=-1),
        visualization_mode=visualize,
        initial_sampling_distribution_kwargs={'num_samples': 20, 'num_active_patches': 1, 'num_passive_patches': 1},
    ).requires_grad_(False).cuda()
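
    # Resize the pretrained positional embeddings to the 60x60 patch grid implied by
    # 480x480 inputs with 8x8 patches; interpolate_pos_encoding is a no-op if the
    # checkpoint was already trained at this resolution.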

    teacher.predictor.encoder.pos_embed = interpolate_pos_encoding(
        teacher.predictor.encoder.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.pos_embed = interpolate_pos_encoding(
        teacher.predictor.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.image_size = image_size

    ## Start extracting segments
    start = time.time()
    batch = []
    image_names = []
    for image_path in sorted(image_list):

        # Prepare input
        image_name = image_path.split('/')[-1]
        image = read_image(image_path)
        if image.shape[0] == 1:
            image = image.expand(3, -1, -1)
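
        # replicate the still image into a 3-frame clip (the CWM predictor used below
        # is a 3-frame model), giving x of shape [1, T=3, C=3, H, W] after batching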

        x = torch.stack([image] * 3, dim=0)
        x = torch.nn.functional.interpolate(x.float(), size=image_size, mode='bicubic')[None] / 255.

        # accumulate images until a full batch is collected
        batch.append(x)
        image_names.append(image_name)
        if len(batch) < batch_size:
            continue
        # NOTE: any final partial batch (fewer than batch_size images) is left unprocessed
        x = torch.cat(batch, dim=0)
        batch = []
        image_names = []

        _x = x.to(torch.float16).cuda()

        targets_list = []
        # extract segments iteratively
        for n in range(args.num_iter):

            # Compute predominance map from DINO on the first frame
            # (x is [B, T, C, H, W]; x[:, 0] selects the RGB frame for DINO)
            if n == 0:
                predominance, _, feats, painting = get_dino_predominance(x[:, 0].cuda(), dims=dims, img_size=image_size)
            else:
                raise NotImplementedError('Iterative re-masking of the DINO affinity is not implemented')
                predominance, _, feats, painting = get_dino_predominance(x[:, 0].cuda(),
                                                                         current_mask=current_mask.cuda(),
                                                                         painting=painting, dims=dims,
                                                                         img_size=image_size)

            # mask out segments that are already extracted
            if n > 0:
                for mask in targets_list:
                    predominance[0, 0][mask[0, 0].cuda()] = 0


            # extract segments given predominance map
            with torch.cuda.amp.autocast(enabled=True):
                targets = teacher(_x, sampling_distribution=predominance)[0]
                print('targets.shape', targets.shape)
                if n == 0:
                    targets_list = [targets.cpu() >= thresh]
                else:
                    ratio = targets.mean()
                    mask = targets.cpu() >= thresh
                    iou = 0
                    match_idx = None

                    for idx, existing_mask in enumerate(targets_list):
                        _iou = metric.IoU(mask[0, 0], existing_mask[0, 0])
                        if _iou > iou:
                            iou = _iou
                            match_idx = idx

                    # remove segments if it has large IoU
                    if iou > 0.2 or ratio <= 0.01:
                        mask = torch.zeros_like(mask)
                    # elif iou > 0.1:
                    #     mask[0, 0][targets_list[match_idx][0, 0]] = 0

                    targets_list.append(mask)

                current_mask = F.interpolate(targets, size=dims, mode='bilinear') >= thresh

            vid_name = image_path
            save_dict[image_name] = targets_list
            if visualize:
                vis_results(x, targets_list, predominance, None, vid_name.split('/')[-2] + '.png')

            if (len(save_dict) + 1) % 1 == 0:
                total = len(image_list)
                num_completed = len(save_dict)
                avg_time = (time.time() - start) / num_completed
                eta = (total - num_completed) * avg_time / 60.
                print(f'{num_completed} / {total} completed, avg. time per image: {avg_time:.2f} sec, eta: {eta:.1f} mins')
                print('remove save')
                #torch.save(save_dict, args.output)
    ## Save the results
    torch.save(save_dict, args.output)