"""Generate zero-shot segments with a CWM segment teacher seeded by DINO predominance maps.

For every input image, a DINO (ViT-B/8) normalized-cut bipartition yields a
per-image "predominance" map, which is passed to the segment teacher as its
sampling distribution. Thresholded teacher outputs are collected per image
name and written with torch.save to --output.
"""
import torch
import os
import glob
import time
from torchvision.io import read_image, ImageReadMode
import matplotlib.pyplot as plt
from scipy import ndimage
from PIL import Image
#import bbnet.trainval.validator as validator
import modeling_pretrain_cleaned as vmae_transformers
import modeling_pretrain as vmae_transformers_old
#import positional_vmae as pos_transformers
#import big_models as big_transformers
import bbnet.models.preprocessor as preprocessor
import bbnet.models.error as error_generator
from functools import partial
from tqdm import tqdm
from torch.nn import functional as F
import argparse
import sys
import numpy as np
import json
import pycocotools.mask as mask_util

sys.path.append('/ccn2/u/honglinc/CutLER')
sys.path.append('/ccn2/u/honglinc/CutLER/maskcut')
sys.path.append('/ccn2/u/honglinc/CutLER/third_party')
import dino
import maskcut
# from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners, get_masked_affinity_matrix
from third_party.TokenCut.unsupervised_saliency_detection import utils, metric
from third_party.TokenCut.unsupervised_saliency_detection.object_discovery import detect_box
from crf import densecrf

# DINO hyperparameters
vit_arch = 'base'
vit_feat = 'k'
patch_size = 8
# DINO pre-trained model
url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
feat_dim = 768

# NOTE: the backbone is built and moved to the GPU at import time, so importing
# this module requires CUDA.
dino_backbone = dino.ViTFeat(url, feat_dim, vit_arch, vit_feat, patch_size)
dino_backbone = dino_backbone.eval().requires_grad_(False).cuda()

# ImageNet normalization statistics applied to images before the DINO backbone.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def get_affinity_matrix(feats, tau, eps=1e-5):
    """Build a binarized patch-affinity matrix and its degree matrix.

    Args:
        feats: patch features; the normalize/permute/matmul below treat them
            as (B, C, N) with the channel axis at dim 1.
        tau: cosine-similarity threshold; pairs above it get affinity 1.
        eps: value substituted for zero affinities so every row has a
            strictly positive degree.

    Returns:
        (A, D): float32 affinity matrices of shape (B, N, N) and the diagonal
        degree matrices built from A's row sums.
    """
    # Patch-wise cosine similarity.
    feats = F.normalize(feats, p=2, dim=1)
    similarity = feats.permute(0, 2, 1) @ feats
    # Binarize, then replace the zeros by eps (keeps every degree > 0).
    A = (similarity > tau).float()
    A = torch.where(A == 0, torch.full_like(A, eps), A)
    D = torch.diag_embed(A.sum(-1))
    return A, D


def second_smallest_eigenvector(A, D):
    """Return the sign-flipped second-smallest generalized eigenvector per batch item.

    Solves (D - A) v = lambda * D v for the two smallest eigenpairs with
    torch.lobpcg and returns the second eigenvector of each batch item,
    negated. Eigenvector sign is arbitrary; get_dino_predominance decides the
    final orientation.

    Returns:
        Tensor of shape (B, N).
    """
    _, eigenvectors = torch.lobpcg(D - A, B=D, k=2, largest=False)
    second_smallest_vec = eigenvectors[:, :, 1]
    return -second_smallest_vec


def get_salient_areas(second_smallest_vec):
    """Bipartition each row of a (B, N) eigenvector batch at its mean (True = above mean)."""
    avg = second_smallest_vec.mean(-1, keepdim=True)
    return second_smallest_vec > avg


def check_num_fg_corners(bipartition, dims):
    """Count, per batch item, how many of the 4 grid corners are foreground.

    Args:
        bipartition: (B, N) boolean tensor with N == dims[0] * dims[1].
        dims: grid size (rows, cols); a list or a tuple.

    Returns:
        (B,) int tensor with values in [0, 4].
    """
    grid = bipartition.reshape([bipartition.shape[0]] + list(dims))
    corners = (grid[:, 0, 0], grid[:, 0, -1], grid[:, -1, 0], grid[:, -1, -1])
    return sum(corner.int() for corner in corners)


def get_dino_predominance(images, dims=(28, 28), current_mask=None, painting=None, img_size=(224, 224)):
    """Compute a per-image "predominance" map from a DINO normalized-cut bipartition.

    Args:
        images: RGB batch, normalized here with ImageNet statistics; the
            (1, 3, 1, 1) broadcast requires shape (B, 3, H, W). main() passes
            values in [0, 1].
        dims: DINO patch grid (rows, cols). The reshape below requires
            dims[0] * dims[1] == number of DINO patches (presumably
            img_size // patch_size per side for this backbone).
        current_mask, painting: only used by the unfinished iterative path
            (see the NOTE below); leave as None.
        img_size: size the images are resized to before DINO, and the spatial
            size of the returned map.

    Returns:
        (init_dist, A, feats, painting): the (B, 1, *img_size) map with each
        image min-max normalized to [0, 1] independently, the affinity
        matrix, the DINO features, and the painting tensor (zeros of shape
        `dims` when current_mask is None).
    """
    mean = torch.tensor(_IMAGENET_MEAN, device=images.device).view(1, 3, 1, 1)
    std = torch.tensor(_IMAGENET_STD, device=images.device).view(1, 3, 1, 1)
    input_dino = F.interpolate((images - mean) / std, size=img_size, mode='bilinear')
    feats = dino_backbone(input_dino)  # [B, C, N]
    B = feats.shape[0]

    if current_mask is None:
        painting = feats.new_zeros(tuple(dims))
    else:
        # NOTE(review): get_masked_affinity_matrix is neither defined nor imported
        # in this file (it only appears in the commented-out maskcut import), so
        # this branch raises NameError. main() never reaches it.
        feats, painting = get_masked_affinity_matrix(painting, feats, current_mask, ps=dims[0])

    A, D = get_affinity_matrix(feats, tau=0.15)
    second_smallest_vec = second_smallest_eigenvector(A, D)
    bipartition = get_salient_areas(second_smallest_vec)

    # Flip an image's eigenvector when (1) its peak |value| lies outside the
    # foreground, or (2) the foreground touches >= 2 corners (object-centric bias).
    peak_idx = second_smallest_vec.abs().argmax(dim=-1)
    batch_idx = torch.arange(B, device=peak_idx.device)
    reverse = ~bipartition[batch_idx, peak_idx]
    reverse |= check_num_fg_corners(bipartition, dims) >= 2
    second_smallest_vec = torch.where(reverse[:, None], 1 - second_smallest_vec, second_smallest_vec)

    pred_map = second_smallest_vec.to(images.device).reshape(B, 1, dims[0], dims[1])
    pred_map = F.interpolate(pred_map, size=img_size, mode='bicubic')
    # Normalize every image on its own, so maps do not depend on batch-mates.
    # clamp_min avoids 0/0 for a constant map.
    pred_map = pred_map - pred_map.amin(dim=(1, 2, 3), keepdim=True)
    pred_map = pred_map / pred_map.amax(dim=(1, 2, 3), keepdim=True).clamp_min(1e-12)
    return pred_map.detach().contiguous(), A, feats, painting


def interpolate_pos_encoding(pos_embed, n_frames, h, w):
    """Resample a spatio-temporal positional embedding to an h x w patch grid per frame.

    Args:
        pos_embed: (1, N, C) embedding laid out as n_frames square grids of
            old_h x old_w tokens (the view below requires this layout).
        n_frames: number of frames (the same before and after resampling).
        h, w: target per-frame grid size.

    Returns:
        pos_embed itself if it already holds n_frames * h * w tokens, otherwise
        a new (1, n_frames * h * w, C) tensor resampled bicubically per frame.
    """
    num_tokens = pos_embed.shape[1]
    if num_tokens == h * w * n_frames:
        return pos_embed
    old_side = int((num_tokens / n_frames) ** 0.5)
    grid = pos_embed.view(1, n_frames, old_side, old_side, -1).flatten(0, 1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(h, w), mode='bicubic')
    return grid.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)


def vis_results(x, targets_dict, predominance, annotation, name):
    """Show, per batch item, the input frame, its predominance map and every segment.

    Args:
        x: input batch; x[b, 0] is read as a (C, H, W) image (the first frame).
        targets_dict: sequence of boolean mask tensors; each is indexed as
            v[b, 0] to get an (H, W) mask.
        predominance: (B, 1, H, W) maps from get_dino_predominance.
        annotation: unused.
        name: unused; kept for call compatibility (the figure is shown, not saved).
    """
    num_images = x.shape[0]
    num_cols = 2 + len(targets_dict)
    # squeeze=False keeps `axs` 2-D even for a single image.
    fig, axs = plt.subplots(num_images, num_cols, figsize=(3 * num_cols, 2 * num_images), squeeze=False)
    for b in range(num_images):
        img = x[b, 0].permute(1, 2, 0).cpu()
        axs[b, 0].imshow(img)
        axs[b, 0].set_title('Image')
        axs[b, 1].imshow(predominance[b, 0].cpu())
        axs[b, 1].set_title('Predominace')
        for i, seg in enumerate(targets_dict):
            seg = seg[b, 0][..., None]
            # Keep the image inside the segment and paint everything else white.
            axs[b, 2 + i].imshow(seg * img + ~seg * torch.ones_like(img))
            axs[b, 2 + i].set_title(f'Segment {i}', fontsize=10)
    for a in axs.flat:
        a.set_axis_off()
    plt.show()
    plt.close(fig)


def _expand_input_paths(patterns):
    """Expand --input_pattern (a glob string or a list of globs/paths) into a sorted path list."""
    if isinstance(patterns, str):
        patterns = [patterns]
    paths = []
    for pattern in patterns:
        if os.path.isfile(pattern):
            # Literal path (e.g. already expanded by the shell).
            paths.append(pattern)
        else:
            paths.extend(glob.glob(pattern))
    return sorted(paths)


def _load_image(image_path, image_size):
    """Read an image as RGB and return a (1, 3, 3, H, W) float clip in [0, 1].

    The image is repeated 3 times along dim 1, presumably the frame axis that
    the 3-frame teacher expects; dim 2 holds the RGB channels.
    """
    # ImageReadMode.RGB also converts grayscale and RGBA inputs to 3 channels.
    image = read_image(image_path, ImageReadMode.RGB)
    clip = torch.stack([image] * 3, dim=0)
    clip = F.interpolate(clip.float(), size=image_size, mode='bicubic') / 255.
    # Bicubic resampling can overshoot the input range.
    return clip.clamp_(0., 1.)[None]


def _segment_batch(teacher, x, num_iter, dims, image_size, thresh):
    """Run DINO predominance + the segment teacher on one batch of clips.

    Args:
        teacher: segment teacher built in main(); called as
            teacher(clips_fp16, sampling_distribution=predominance)[0].
        x: CPU batch of clips, (B, 3, 3, H, W) as built by _load_image.
        num_iter: number of extraction iterations; only 1 is implemented.
        dims: DINO patch grid (rows, cols).
        image_size: (H, W) of the clips.
        thresh: threshold used to binarize the teacher output.

    Returns:
        (targets_list, predominance): one boolean CPU mask tensor per
        iteration (teacher-output shape, batch first), and the last
        predominance map.

    Raises:
        ValueError: on the second iteration (iterative extraction is unfinished).
    """
    clips = x.to(torch.float16).cuda()
    # DINO sees frame 0 of every clip as a (B, 3, H, W) RGB batch. The original
    # x[:, :, 0] took the red channel of each frame instead.
    first_frames = x[:, 0].cuda()
    targets_list = []
    current_mask = None
    painting = None
    predominance = None
    for n in range(num_iter):
        if n == 0:
            predominance, _, _, painting = get_dino_predominance(first_frames, dims=dims, img_size=image_size)
        else:
            raise ValueError('Not implemented')
            # Intended iterative step, unreachable until get_masked_affinity_matrix exists.
            predominance, _, _, painting = get_dino_predominance(
                first_frames, current_mask=current_mask.cuda(), painting=painting,
                dims=dims, img_size=image_size)

        # Mask out segments that are already extracted.
        # NOTE(review): only batch item 0 is handled here (currently unreachable).
        if n > 0:
            for mask in targets_list:
                predominance[0, 0][mask[0, 0].cuda()] = 0

        with torch.cuda.amp.autocast(enabled=True):
            targets = teacher(clips, sampling_distribution=predominance)[0]

        if n == 0:
            targets_list = [targets.cpu() >= thresh]
        else:
            ratio = targets.mean()
            mask = targets.cpu() >= thresh
            # NOTE(review): IoU is only computed for batch item 0 (currently unreachable).
            best_iou = max([0] + [metric.IoU(mask[0, 0], existing[0, 0]) for existing in targets_list])
            # Drop segments that overlap an existing one or are nearly empty.
            if best_iou > 0.2 or ratio <= 0.01:
                mask = torch.zeros_like(mask)
            targets_list.append(mask)
        current_mask = F.interpolate(targets, size=dims, mode='bilinear') >= thresh
    return targets_list, predominance


def main():
    """Parse the CLI, build the teacher, segment every image, and save the masks."""
    # Imported here because only the script entry point needs the teacher.
    import bbnet.models.teachers as teachers

    parser = argparse.ArgumentParser(description='Generate zero-shot segments from CWM model')
    parser.add_argument('--input_pattern', default='/ccn2/u/honglinc/datasets/coco/images/val2017/*', nargs='+', type=str,
                        help='Pattern for input images')
    parser.add_argument('--output', default='./output.pt', type=str, help='output path for saving the results')
    parser.add_argument('--num_iter', default=1, type=int, help='number of iterations')
    parser.add_argument('--visualize', action='store_true', help='Visualize the results')
    args = parser.parse_args()
    if args.num_iter < 1:
        parser.error('--num_iter must be >= 1')

    ## Prepare for the extraction
    image_list = _expand_input_paths(args.input_pattern)
    thresh = 0.5
    visualize = args.visualize
    save_dict = {}
    image_size = [480, 480]
    dims = [s // patch_size for s in image_size]
    batch_size = 10

    ## Load pretrained model
    default_model_dir = '/ccn2/u/honglinc/cwm_checkpoints/'
    model_func = vmae_transformers.vitb_8x8patch_3frames
    ckpt_path = 'ablation_3frame_no_clumping_mr0.90_extra_data_ep400'
    teacher = teachers.iteration_segment_teacher_with_filter(
        model_func=model_func,
        model_path=teachers.get_load_path(os.path.join(default_model_dir, ckpt_path), model_checkpoint=-1),
        visualization_mode=visualize,
        initial_sampling_distribution_kwargs={'num_samples': 20, 'num_active_patches': 1, 'num_passive_patches': 1},
    ).requires_grad_(False).cuda()

    # Resize the positional embeddings to 3 frames of a dims[0] x dims[1] patch grid.
    teacher.predictor.encoder.pos_embed = interpolate_pos_encoding(
        teacher.predictor.encoder.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.pos_embed = interpolate_pos_encoding(
        teacher.predictor.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.image_size = image_size

    ## Start extracting segments
    start = time.time()
    total = len(image_list)
    for first in range(0, total, batch_size):
        # Every image, including the final partial batch, is processed exactly once.
        batch_paths = image_list[first:first + batch_size]
        x = torch.cat([_load_image(path, image_size) for path in batch_paths], dim=0)
        targets_list, predominance = _segment_batch(teacher, x, args.num_iter, dims, image_size, thresh)

        # Store each image's own masks (batch dim kept as size 1) under its file name.
        for b, image_path in enumerate(batch_paths):
            save_dict[os.path.basename(image_path)] = [t[b:b + 1] for t in targets_list]

        if visualize:
            parent_dir = os.path.basename(os.path.dirname(os.path.abspath(batch_paths[-1])))
            vis_results(x, targets_list, predominance, None, parent_dir + '.png')

        num_completed = first + len(batch_paths)
        avg_time = (time.time() - start) / num_completed
        eta = (total - num_completed) * avg_time / 60.
        print(f'{num_completed} / {total} completed, avg. time per image: {avg_time:.2f} sec, eta: {eta:.1f} mins')

    ## Save the results
    torch.save(save_dict, args.output)


if __name__ == "__main__":
    main()