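"""Generate zero-shot image segments with a CWM teacher model.

For every image matched by --input_pattern, a DINO-based predominance map seeds
an iterative segment-extraction teacher; the resulting boolean masks are stored
per image name in a dict and periodically saved to --output via torch.save.
"""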
import torch
import os
import glob
import time
from torchvision.io import read_image
import matplotlib.pyplot as plt
from scipy import ndimage
from PIL import Image
import bbnet.trainval.validator as validator
import modeling_pretrain_cleaned as vmae_transformers
import modeling_pretrain as vmae_transformers_old
import positional_vmae as pos_transformers
import big_models as big_transformers
import bbnet.models.preprocessor as preprocessor
import bbnet.models.error as error_generator
from functools import partial
import bbnet.models.teachers as teachers
from tqdm import tqdm
from torch.nn import functional as F
import argparse
import sys
import numpy as np
import json
import pycocotools.mask as mask_util
sys.path.append('/ccn2/u/honglinc/CutLER')
sys.path.append('/ccn2/u/honglinc/CutLER/maskcut')
sys.path.append('/ccn2/u/honglinc/CutLER/third_party')
import dino
from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners, get_masked_affinity_matrix
from third_party.TokenCut.unsupervised_saliency_detection import utils, metric
from third_party.TokenCut.unsupervised_saliency_detection.object_discovery import detect_box
from crf import densecrf
# DINO hyperparameters
vit_arch = 'base'
vit_feat = 'k'
patch_size = 8
# DINO pre-trained model
url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
feat_dim = 768
dino_backbone = dino.ViTFeat(url, feat_dim, vit_arch, vit_feat, patch_size)
dino_backbone = dino_backbone.eval().requires_grad_(False).cuda()


def get_dino_predominance(images, dims=[28, 28], current_mask=None, painting=None, img_size=[224, 224]):
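    """Compute per-image saliency ("predominance") maps from DINO patch features.

    For each image, an affinity matrix is built over DINO features and the
    second-smallest eigenvector of the resulting graph (MaskCut/TokenCut-style
    spectral bipartition) is taken as a soft foreground map, optionally flipped
    based on the foreground-corner and eigenvector-peak heuristics, then
    upsampled to `img_size` and min-max normalized to [0, 1]. If `current_mask`
    is given, previously extracted patches are masked out of the affinity
    computation via `get_masked_affinity_matrix`.

    Returns the stacked maps plus the affinity matrix, features, and painting
    of the last processed image.
    """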
    input_dino = images
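    # normalize with the standard ImageNet mean/std before feeding DINO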
    input_dino = input_dino - torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input_dino.device)
    input_dino = input_dino / torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input_dino.device)
    # input_dino = images.tensor
    input_dino = torch.nn.functional.interpolate(input_dino, size=img_size, mode='bilinear')
    features = dino_backbone(input_dino)

    predominance_map = []

    for i in range(features.shape[0]):
        feats = features[i]
        if current_mask is None:
            painting = torch.from_numpy(np.zeros(dims))
            painting = painting.to(feats)
        else:
            feats, painting = get_masked_affinity_matrix(painting, feats, current_mask, ps=dims[0])

        A, D = get_affinity_matrix(feats, tau=0.15)
        # get the second-smallest eigenvector
        _, second_smallest_vec = second_smallest_eigenvector(A, D)
        # get salient area
        bipartition = get_salient_areas(second_smallest_vec)

        # check if we should reverse the partition based on:
        # 1) peak of the 2nd smallest eigvec 2) object centric bias
        seed = np.argmax(np.abs(second_smallest_vec))
        nc = check_num_fg_corners(bipartition, dims)
        if nc >= 2:
            reverse = True
        else:
            reverse = bipartition[seed] != 1
        if reverse:
            second_smallest_vec = 1 - second_smallest_vec
        second_smallest_vec = torch.tensor(second_smallest_vec).to(images.device).contiguous()
        saliency = torch.nn.functional.interpolate(second_smallest_vec.reshape(1, 1, dims[0], dims[1]),
                                                   size=img_size, mode='bilinear')
        # rescale the map to [0, 1]
        saliency = saliency - saliency.min()
        saliency = saliency / saliency.max()
        predominance_map.append(saliency)
    init_dist = torch.cat(predominance_map, dim=0).detach()
    return init_dist, A, feats, painting




def interpolate_pos_encoding(pos_embed, n_frames, h, w):
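    """Interpolate a (1, N, C) positional embedding to a new (h, w) patch grid.

    The embedding is assumed to cover `n_frames` frames over a square grid;
    each frame's grid is bicubically resized to (h, w) and the result is
    flattened back to shape (1, n_frames * h * w, C). Returned unchanged if
    the sizes already match.
    """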
    N = pos_embed.shape[1]
    if N == (h * w * n_frames):
        return pos_embed
    old_h = old_w = int((N / n_frames) ** 0.5)
    patch_pos_embed = pos_embed.view(1, n_frames, old_h, old_w, -1).flatten(0, 1).permute(0, 3, 1, 2)

    patch_pos_embed = F.interpolate(
        patch_pos_embed,
        size=(h, w),
        mode='bicubic',
    )
    return patch_pos_embed.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)



def vis_results(x, targets_list, annotation, name):
    """Show the input image alongside each extracted segment (non-segment pixels whited out)."""
    img = x[0, 0].permute(1, 2, 0).cpu()
    fig, axs = plt.subplots(1, 1 + len(targets_list), figsize=(3 * (1 + len(targets_list)), 3))
    axs[0].imshow(img)
    axs[0].set_title('Image')

    for i, v in enumerate(targets_list):
        v = v[0, 0]  # .cpu()
        axs[1+i].imshow((v[..., None] * img) + (~v[..., None] * torch.ones_like(img)))
        axs[1+i].set_title(f'Segment {i}', fontsize=10)

    for ax in axs:
        ax.set_axis_off()

    plt.show()
    plt.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Generate zero-shot segments from CWM model', add_help=False)
    parser.add_argument('--input_pattern', default='/ccn2/u/honglinc/datasets/coco/images/val2017/*', nargs='+', type=str, help='Pattern for input images')
    parser.add_argument('--output', default='./output.pt', type=str, help='output path for saving the results')
    parser.add_argument('--num_iter', default=1, type=int, help='number of iterations')
    parser.add_argument('--visualize', action='store_true', help='Visualize the results')
    args = parser.parse_args()
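
    # Example invocations (the script filename below is a placeholder, not part of the repo):
    #   python extract_segments.py                                              # default COCO val2017 glob
    #   python extract_segments.py --input_pattern img1.jpg img2.jpg --output ./segments.pt --num_iter 3
    # Note: --input_pattern uses nargs='+', so explicit file paths (or a shell-expanded glob) arrive as a
    # list and are used directly; only the default string pattern is expanded with glob.glob below.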

    ## Prepare for the extraction
    image_list = glob.glob(args.input_pattern) if isinstance(args.input_pattern, str) else args.input_pattern
    thresh = 0.5
    visualize = args.visualize
    save_dict = {}
    image_size = [480, 480]
    patch_size = 8
    dims = [int(s / patch_size) for s in image_size]

    ## Load pretrained model
    default_model_dir = '/ccn2/u/honglinc/cwm_checkpoints/'
    model_func = vmae_transformers.vitb_8x8patch_3frames
    ckpt_path = 'ablation_3frame_no_clumping_mr0.90_extra_data_ep400'
    label = '3 frame 8x8'
    teacher_func = teachers.iteration_segment_teacher_with_filter

    teacher = teacher_func(
        model_func=model_func,
        model_path=teachers.get_load_path(os.path.join(default_model_dir, ckpt_path), model_checkpoint=-1),
        visualization_mode=visualize,
        initial_sampling_distribution_kwargs={'num_samples': 20, 'num_active_patches': 1, 'num_passive_patches': 1},
    ).requires_grad_(False).cuda()

    # Resize the pretrained positional embeddings to the 3-frame, 60x60 patch grid
    # implied by image_size=[480, 480] and patch_size=8.
    teacher.predictor.encoder.pos_embed = interpolate_pos_encoding(
        teacher.predictor.encoder.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.pos_embed = interpolate_pos_encoding(
        teacher.predictor.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.image_size = image_size

    ## Start extracting segments
    start = time.time()


    if os.path.exists(args.output):
        print('Load partial results from: ', args.output)
        save_dict = torch.load(args.output)
        print('Length of existing dict: ', len(save_dict))

    for image_path in image_list:

        # Prepare input
        image_name = image_path.split('/')[-1]

        if image_name in save_dict:
            continue

        image = read_image(image_path)
        if image.shape[0] == 1:
            image = image.expand(3, -1, -1)

        # Replicate the still image into 3 identical frames so it can be fed to the
        # 3-frame video model, resize to image_size, and scale to [0, 1].
        x = torch.stack([image] * 3, dim=0)
        x = torch.nn.functional.interpolate(x.float(), size=image_size, mode='bicubic')[None] / 255.
        _x = x.to(torch.float16).cuda()

        targets_list = []
        # extract segments iteratively
        for n in range(args.num_iter):

            # Compute predominance map from dino
            if n == 0:
                predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(), dims=dims, img_size=image_size)
            else:
                predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(),
                                                                         current_mask=current_mask.cuda(),
                                                                         painting=painting, dims=dims,
                                                                         img_size=image_size)

            if visualize:
                plt.imshow(predominance[0, 0].cpu())
                plt.title(f'Predominance (max:{predominance[0, 0].max()})')
                plt.show()

            # mask out segments that are already extracted
            if n > 0:
                for mask in targets_list:
                    predominance[0, 0][mask[0, 0].cuda()] = 0


            # extract segments given predominance map
            with torch.cuda.amp.autocast(enabled=True):
                targets = teacher(_x, sampling_distribution=predominance)[0]
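                # the teacher returns a soft segment map; it is binarized at `thresh` below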
                if n == 0:
                    targets_list = [targets.cpu() >= thresh]
                else:
                    ratio = targets.mean()
                    mask = targets.cpu() >= thresh
                    iou = 0
                    match_idx = None

                    for idx, existing_mask in enumerate(targets_list):
                        _iou = metric.IoU(mask[0, 0], existing_mask[0, 0])
                        if _iou > iou:
                            iou = _iou
                            match_idx = idx

                    # discard this segment if it largely overlaps an existing one or covers almost no area
                    if iou > 0.2 or ratio <= 0.01:
                        mask = torch.zeros_like(mask)
                    # elif iou > 0.1:
                    #     mask[0, 0][targets_list[match_idx][0, 0]] = 0

                    targets_list.append(mask)

                # downsample the new segment to the patch grid so it can be masked out of
                # the DINO affinity computation in the next iteration
                current_mask = F.interpolate(targets, size=dims, mode='bilinear') >= thresh

            vid_name = image_path
            save_dict[image_name] = targets_list
            if visualize:
                vis_results(x, targets_list, None, vid_name.split('/')[-2] + '.png')

            # report progress and checkpoint partial results (the modulus of 1 makes this unconditional)
            if (len(save_dict) + 1) % 1 == 0:
                total = len(image_list)
                num_completed = len(save_dict)
                avg_time = (time.time() - start) / num_completed
                eta = (total - num_completed) * avg_time / 60.
                print(f'{num_completed} / {total} completed, avg. time per image: {avg_time:.2f} sec, eta: {eta:.1f} mins')
                torch.save(save_dict, args.output)
    ## Save the results
    torch.save(save_dict, args.output)