# cwm/eval/Segmentation/archive/generate_zero_shot_segments_v2.py
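"""Generate zero-shot segmentation masks with a counterfactual world model (CWM).

For each input image, a DINO-based "predominance" (saliency) map is computed and
used as the sampling distribution for a CWM segment teacher, which extracts one
segment per iteration. The resulting boolean masks are collected in a dictionary
keyed by image filename and saved with torch.save.
"""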
import torch
import os
import glob
import time
from torchvision.io import read_image
import matplotlib.pyplot as plt
from scipy import ndimage
from PIL import Image
#import bbnet.trainval.validator as validator
import modeling_pretrain_cleaned as vmae_transformers
import modeling_pretrain as vmae_transformers_old
#import positional_vmae as pos_transformers
#import big_models as big_transformers
import bbnet.models.preprocessor as preprocessor
import bbnet.models.error as error_generator
from functools import partial
import bbnet.models.teachers as teachers
from tqdm import tqdm
from torch.nn import functional as F
import argparse
import sys
import numpy as np
import json
import pycocotools.mask as mask_util
sys.path.append('/ccn2/u/honglinc/CutLER')
sys.path.append('/ccn2/u/honglinc/CutLER/maskcut')
sys.path.append('/ccn2/u/honglinc/CutLER/third_party')
import dino
import maskcut
# from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners, get_masked_affinity_matrix
from third_party.TokenCut.unsupervised_saliency_detection import utils, metric
from third_party.TokenCut.unsupervised_saliency_detection.object_discovery import detect_box
from crf import densecrf
# DINO hyperparameters
vit_arch = 'base'
vit_feat = 'k'
patch_size = 8
# DINO pre-trained model
url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
feat_dim = 768
dino_backbone = dino.ViTFeat(url, feat_dim, vit_arch, vit_feat, patch_size)
dino_backbone = dino_backbone.eval().requires_grad_(False).cuda()
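
# The helpers below mirror the MaskCut / TokenCut graph-cut utilities (see the
# commented-out `from maskcut import ...` above), re-implemented here to operate on
# batched DINO features: build a patch-affinity matrix, solve for the second-smallest
# generalized eigenvector, and bipartition patches into salient vs. background.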
def get_affinity_matrix(feats, tau, eps=1e-5):
    # get affinity matrix via measuring patch-wise cosine similarity
    feats = F.normalize(feats, p=2, dim=1)
    A = (feats.permute(0, 2, 1) @ feats)
    # convert the affinity matrix to a binary one
    A = A > tau
    eps = torch.ones_like(A) * eps
    A = torch.where(A.float() == 0, eps, A)
    d_i = A.sum(-1)
    D = torch.diag_embed(d_i)
    return A, D

def second_smallest_eigenvector(A, D):
    # get the second smallest eigenvector from affinity matrix
    _, eigenvectors = torch.lobpcg(D - A, B=D, k=2, largest=False)
    second_smallest_vec = eigenvectors[:, :, 1]
    return -second_smallest_vec

def get_salient_areas(second_smallest_vec):
    # get the area corresponding to salient objects
    avg = second_smallest_vec.mean(-1, keepdims=True)
    bipartition = second_smallest_vec > avg
    return bipartition

def check_num_fg_corners(bipartition, dims):
    # check number of corners belonging to the foreground
    dims = [bipartition.shape[0]] + dims
    bipartition_ = bipartition.reshape(dims)
    top_l, top_r, bottom_l, bottom_r = bipartition_[:, 0, 0], bipartition_[:, 0, -1], bipartition_[:, -1, 0], bipartition_[:, -1, -1]
    nc = top_l.int() + top_r.int() + bottom_l.int() + bottom_r.int()
    return nc

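
# get_dino_predominance combines the pieces above: it runs the DINO backbone on a
# normalized, resized image, bipartitions the patch affinity graph, and upsamples the
# (possibly sign-flipped) second eigenvector into a per-pixel saliency map in [0, 1].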
def get_dino_predominance(images, dims=[28, 28], current_mask=None, painting=None, img_size=[224, 224]):
    # normalize with ImageNet statistics before feeding DINO
    input_dino = images
    input_dino = input_dino - torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input_dino.device)
    input_dino = input_dino / torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input_dino.device)
    input_dino = torch.nn.functional.interpolate(input_dino, size=img_size, mode='bilinear')
    feats = dino_backbone(input_dino)  # [B, C, N]
    B = feats.shape[0]
    predominance_map = []

    if current_mask is None:
        painting = torch.from_numpy(np.zeros(dims))
        painting = painting.to(feats)
    else:
        # mask out patches belonging to previously extracted segments
        feats, painting = maskcut.get_masked_affinity_matrix(painting, feats, current_mask, ps=dims[0])

    A, D = get_affinity_matrix(feats, tau=0.15)
    # get the second-smallest eigenvector
    second_smallest_vec = second_smallest_eigenvector(A, D)
    # get salient area
    bipartition = get_salient_areas(second_smallest_vec)

    # check if we should reverse the partition based on:
    # 1) peak of the 2nd smallest eigenvector, 2) object-centric bias
    batch_inds = torch.arange(second_smallest_vec.shape[0]).to(second_smallest_vec).unsqueeze(0)
    seed = torch.argmax(second_smallest_vec.abs(), dim=-1).unsqueeze(0)
    seed = torch.cat([batch_inds, seed], dim=0).long()
    reverse = bipartition[list(seed)] != 1
    nc = check_num_fg_corners(bipartition, dims)
    reverse[nc >= 2] = True
    second_smallest_vec[reverse] = 1 - second_smallest_vec[reverse]

    # upsample the eigenvector to image resolution and normalize to [0, 1]
    second_smallest_vec = second_smallest_vec.detach().to(images.device).contiguous()
    pred_map = torch.nn.functional.interpolate(
        second_smallest_vec.reshape(B, 1, dims[0], dims[1]), size=img_size, mode='bicubic')
    pred_map -= pred_map.min()
    pred_map /= pred_map.max()
    predominance_map.append(pred_map)

    init_dist = torch.cat(predominance_map, dim=0).detach().contiguous()
    return init_dist, A, feats, painting

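
# The CWM predictor is applied at 480x480 with 8x8 patches (a 60x60 token grid per
# frame), so its pretrained positional embeddings must be resized to that grid; the
# helper below bicubically interpolates them per frame.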
def interpolate_pos_encoding(pos_embed, n_frames, h, w):
    N = pos_embed.shape[1]
    if N == (h * w * n_frames):
        return pos_embed
    old_h = old_w = int((N / n_frames) ** 0.5)
    patch_pos_embed = pos_embed.view(1, n_frames, old_h, old_w, -1).flatten(0, 1).permute(0, 3, 1, 2)
    patch_pos_embed = F.interpolate(
        patch_pos_embed,
        size=(h, w),
        mode='bicubic',
    )
    return patch_pos_embed.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)

def vis_results(x, targets_list, predominance, annotation, name):
    # Plot, for each image in the batch: the input, the predominance map, and each segment.
    B = x.shape[0]
    n_cols = 2 + len(targets_list)
    fig, axs = plt.subplots(B, n_cols, figsize=(3 * n_cols, 2 * B))
    for b in range(B):
        img = x[b, 0].permute(1, 2, 0).cpu()
        axs[b, 0].imshow(img)
        axs[b, 0].set_title('Image')
        axs[b, 1].imshow(predominance[b, 0].cpu())
        axs[b, 1].set_title('Predominance')
        for i, v in enumerate(targets_list):
            v = v[b, 0]  # .cpu()
            axs[b, 2 + i].imshow((v[..., None] * img) + (~v[..., None] * torch.ones_like(img)))
            axs[b, 2 + i].set_title(f'Segment {i}', fontsize=10)
    for ax in axs:
        for a in ax:
            a.set_axis_off()
    plt.show()
    plt.close()

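
# The extraction loop below accumulates images into batches of `batch_size`, computes
# one DINO predominance map per batch, queries the teacher under autocast, and (for
# num_iter > 1) suppresses regions that were already segmented in earlier iterations.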
if __name__ == "__main__":
    parser = argparse.ArgumentParser('Generate zero-shot segments from CWM model', add_help=False)
    parser.add_argument('--input_pattern', default='/ccn2/u/honglinc/datasets/coco/images/val2017/*', nargs='+', type=str, help='Pattern for input images')
    parser.add_argument('--output', default='./output.pt', type=str, help='output path for saving the results')
    parser.add_argument('--num_iter', default=1, type=int, help='number of iterations')
    parser.add_argument('--visualize', action='store_true', help='Visualize the results')
    args = parser.parse_args()

    ## Prepare for the extraction
    image_list = glob.glob(args.input_pattern) if isinstance(args.input_pattern, str) else args.input_pattern
    thresh = 0.5
    visualize = args.visualize
    save_dict = {}
    image_size = [480, 480]
    patch_size = 8
    dims = [int(s / patch_size) for s in image_size]
    batch_size = 10

    ## Load pretrained model
    default_model_dir = '/ccn2/u/honglinc/cwm_checkpoints/'
    model_func = vmae_transformers.vitb_8x8patch_3frames
    ckpt_path = 'ablation_3frame_no_clumping_mr0.90_extra_data_ep400'
    label = '3 frame 8x8'
    teacher_func = teachers.iteration_segment_teacher_with_filter
    teacher = teacher_func(
        model_func=model_func,
        model_path=teachers.get_load_path(os.path.join(default_model_dir, ckpt_path), model_checkpoint=-1),
        visualization_mode=visualize,
        initial_sampling_distribution_kwargs={'num_samples': 20, 'num_active_patches': 1, 'num_passive_patches': 1},
    ).requires_grad_(False).cuda()

    # Resize the positional embeddings to the token grid used at the evaluation resolution
    teacher.predictor.encoder.pos_embed = interpolate_pos_encoding(
        teacher.predictor.encoder.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.pos_embed = interpolate_pos_encoding(
        teacher.predictor.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.image_size = image_size

    ## Start extracting segments
    start = time.time()
    batch = []
    image_names = []
    for image_path in sorted(image_list):
        # Prepare input: read image, force 3 channels, replicate into 3 identical frames
        image_name = image_path.split('/')[-1]
        image = read_image(image_path)
        if image.shape[0] == 1:
            image = image.expand(3, -1, -1)
        x = torch.stack([image] * 3, dim=0)
        x = torch.nn.functional.interpolate(x.float(), size=image_size, mode='bicubic')[None] / 255.

        # Accumulate images until a full batch is available
        # (note: a final partial batch smaller than batch_size is not processed)
        batch.append(x)
        image_names.append(image_name)
        if len(batch) < batch_size:
            continue
        x = torch.cat(batch, dim=0)
        batch_names = list(image_names)
        batch = []
        image_names = []

        _x = x.to(torch.float16).cuda()
        targets_list = []
        # extract segments iteratively
        for n in range(args.num_iter):
            # Compute predominance map from DINO
            if n == 0:
                predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(), dims=dims, img_size=image_size)
            else:
                raise ValueError('Not implemented')  # iterative re-masking is currently disabled
                predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(),
                                                                         current_mask=current_mask.cuda(),
                                                                         painting=painting, dims=dims,
                                                                         img_size=image_size)
            # mask out segments that are already extracted
            if n > 0:
                for mask in targets_list:
                    predominance[0, 0][mask[0, 0].cuda()] = 0
            # extract segments given the predominance map
            with torch.cuda.amp.autocast(enabled=True):
                targets = teacher(_x, sampling_distribution=predominance)[0]
            if n == 0:
                targets_list = [targets.cpu() >= thresh]
            else:
                ratio = targets.mean()
                mask = targets.cpu() >= thresh
                iou = 0
                match_idx = None
                for idx, existing_mask in enumerate(targets_list):
                    _iou = metric.IoU(mask[0, 0], existing_mask[0, 0])
                    if _iou > iou:
                        iou = _iou
                        match_idx = idx
                # discard the new segment if it largely overlaps an existing one or is nearly empty
                if iou > 0.2 or ratio <= 0.01:
                    mask = torch.zeros_like(mask)
                # elif iou > 0.1:
                #     mask[0, 0][targets_list[match_idx][0, 0]] = 0
                targets_list.append(mask)
            current_mask = F.interpolate(targets, size=dims, mode='bilinear') >= thresh

        # store the extracted segments for every image in the batch
        for b_idx, name in enumerate(batch_names):
            save_dict[name] = [t[b_idx:b_idx + 1] for t in targets_list]

        if visualize:
            vis_results(x, targets_list, predominance, None, image_path.split('/')[-2] + '.png')

        # report progress
        total = len(image_list)
        num_completed = len(save_dict)
        avg_time = (time.time() - start) / num_completed
        eta = (total - num_completed) * avg_time / 60.
        print(f'{num_completed} / {total} completed, avg. time per image: {avg_time:.2f} sec, eta: {eta:.1f} mins')

    ## Save the results
    torch.save(save_dict, args.output)
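
# Example invocation (the paths below are illustrative, not required):
#   python generate_zero_shot_segments_v2.py \
#       --input_pattern '/path/to/coco/val2017/*' \
#       --output ./coco_val2017_segments.pt
# The output can later be reloaded with `torch.load(args.output)`, giving a dict that
# maps each image filename to its list of boolean segment masks.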