# cwm/eval/Segmentation/archive/generate_zero_shot_segments.py
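"""Generate zero-shot segmentation masks with a counterfactual world model (CWM).

For each input image, a DINO-based "predominance" (saliency) map seeds a CWM
segment teacher, which extracts one mask per iteration; masks that overlap
previously extracted segments or cover too few pixels are discarded. Results
are stored as a dict mapping image filename -> list of boolean masks and saved
with torch.save.
"""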
import torch
import os
import glob
import time
from torchvision.io import read_image
import matplotlib.pyplot as plt
from scipy import ndimage
from PIL import Image
import bbnet.trainval.validator as validator
import modeling_pretrain_cleaned as vmae_transformers
import modeling_pretrain as vmae_transformers_old
import positional_vmae as pos_transformers
import big_models as big_transformers
import bbnet.models.preprocessor as preprocessor
import bbnet.models.error as error_generator
from functools import partial
import bbnet.models.teachers as teachers
from tqdm import tqdm
from torch.nn import functional as F
import argparse
import sys
import numpy as np
import json
import pycocotools.mask as mask_util
sys.path.append('/ccn2/u/honglinc/CutLER')
sys.path.append('/ccn2/u/honglinc/CutLER/maskcut')
sys.path.append('/ccn2/u/honglinc/CutLER/third_party')
import dino
from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners, get_masked_affinity_matrix
from third_party.TokenCut.unsupervised_saliency_detection import utils, metric
from third_party.TokenCut.unsupervised_saliency_detection.object_discovery import detect_box
from crf import densecrf
# DINO hyperparameters
vit_arch = 'base'
vit_feat = 'k'
patch_size = 8
# DINO pre-trained model
url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
feat_dim = 768
dino_backbone = dino.ViTFeat(url, feat_dim, vit_arch, vit_feat, patch_size)
dino_backbone = dino_backbone.eval().requires_grad_(False).cuda()
def get_dino_predominance(images, dims=[28, 28], current_mask=None, painting=None, img_size=[224, 224]):
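    """Compute a per-pixel "predominance" (saliency) map from DINO features.

    Images are normalized with ImageNet statistics, resized to `img_size`, and
    passed through the frozen DINO ViT-B/8 backbone. For each image, a patch
    affinity matrix is built and its second-smallest generalized eigenvector is
    taken (as in MaskCut / TokenCut), optionally flipping the partition based on
    the eigenvector peak and a foreground-corner heuristic. If `current_mask` is
    given, previously extracted patches are masked out of the affinity matrix.

    Returns:
        init_dist: (B, 1, *img_size) predominance maps normalized to [0, 1].
        A: affinity matrix of the last image in the batch.
        feats: (possibly masked) DINO features of the last image.
        painting: running mask of the patches removed so far.
    """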
    input_dino = images
    input_dino = input_dino - torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input_dino.device)
    input_dino = input_dino / torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input_dino.device)
    # input_dino = images.tensor
    input_dino = torch.nn.functional.interpolate(input_dino, size=img_size, mode='bilinear')
    features = dino_backbone(input_dino)
    predominance_map = []
    for i in range(features.shape[0]):
        feats = features[i]
        if current_mask is None:
            painting = torch.from_numpy(np.zeros(dims))
            painting = painting.to(feats)
        else:
            feats, painting = get_masked_affinity_matrix(painting, feats, current_mask, ps=dims[0])
        A, D = get_affinity_matrix(feats, tau=0.15)
        # get the second-smallest eigenvector
        _, second_smallest_vec = second_smallest_eigenvector(A, D)
        # get salient area
        bipartition = get_salient_areas(second_smallest_vec)
        # check if we should reverse the partition based on:
        # 1) peak of the 2nd smallest eigvec 2) object-centric bias
        seed = np.argmax(np.abs(second_smallest_vec))
        nc = check_num_fg_corners(bipartition, dims)
        if nc >= 2:
            reverse = True
        else:
            reverse = bipartition[seed] != 1
        if reverse:
            second_smallest_vec = 1 - second_smallest_vec
        second_smallest_vec = torch.tensor(second_smallest_vec).to(images.device).contiguous()
        map = torch.nn.functional.interpolate(second_smallest_vec.reshape(1, 1, dims[0], dims[1]), size=img_size,
                                              mode='bilinear')
        map -= map.min()
        map /= map.max()
        predominance_map.append(map)
    init_dist = torch.cat(predominance_map, dim=0).detach()
    return init_dist, A, feats, painting
def interpolate_pos_encoding(pos_embed, n_frames, h, w):
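    """Resize per-frame positional embeddings to a new (h, w) token grid.

    `pos_embed` is expected to have shape (1, n_frames * old_h * old_w, C). If
    the token count already matches n_frames * h * w it is returned unchanged;
    otherwise each frame's grid is bicubically interpolated to (h, w) and the
    result is flattened back to (1, n_frames * h * w, C).
    """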
    N = pos_embed.shape[1]
    if N == (h * w * n_frames):
        return pos_embed
    old_h = old_w = int((N / n_frames) ** 0.5)
    patch_pos_embed = pos_embed.view(1, n_frames, old_h, old_w, -1).flatten(0, 1).permute(0, 3, 1, 2)
    patch_pos_embed = F.interpolate(
        patch_pos_embed,
        size=(h, w),
        mode='bicubic',
    )
    return patch_pos_embed.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
def vis_results(x, targets_list, annotation, name):
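    """Plot the input image next to each extracted segment.

    `targets_list` is a list of boolean masks of shape (1, 1, H, W); each
    segment is shown on a white background. `annotation` and `name` are
    currently unused.
    """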
    img = x[0, 0].permute(1, 2, 0).cpu()
    fig, axs = plt.subplots(1, 1 + len(targets_list), figsize=(3 * len(targets_list), 3))
    axs[0].imshow(img)
    axs[0].set_title('Image')
    for i, v in enumerate(targets_list):
        v = v[0, 0]  # .cpu()
        axs[1 + i].imshow((v[..., None] * img) + (~v[..., None] * torch.ones_like(img)))
        axs[1 + i].set_title(f'Segment {i}', fontsize=10)
    for ax in axs:
        ax.set_axis_off()
    plt.show()
    plt.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser('Generate zero-shot segments from CWM model', add_help=False)
    parser.add_argument('--input_pattern', default='/ccn2/u/honglinc/datasets/coco/images/val2017/*', nargs='+', type=str, help='Pattern for input images')
    parser.add_argument('--output', default='./output.pt', type=str, help='output path for saving the results')
    parser.add_argument('--num_iter', default=1, type=int, help='number of iterations')
    parser.add_argument('--visualize', action='store_true', help='Visualize the results')
    args = parser.parse_args()

    ## Prepare for the extraction
    image_list = glob.glob(args.input_pattern) if isinstance(args.input_pattern, str) else args.input_pattern
    thresh = 0.5
    visualize = args.visualize
    save_dict = {}
    image_size = [480, 480]
    patch_size = 8
    dims = [int(s / patch_size) for s in image_size]

    ## Load pretrained model
    default_model_dir = '/ccn2/u/honglinc/cwm_checkpoints/'
    model_func = vmae_transformers.vitb_8x8patch_3frames
    ckpt_path = 'ablation_3frame_no_clumping_mr0.90_extra_data_ep400'
    label = '3 frame 8x8'
    teacher_func = teachers.iteration_segment_teacher_with_filter
    teacher = teacher_func(
        model_func=model_func,
        model_path=teachers.get_load_path(os.path.join(default_model_dir, ckpt_path), model_checkpoint=-1),
        visualization_mode=visualize,
        initial_sampling_distribution_kwargs={'num_samples': 20, 'num_active_patches': 1, 'num_passive_patches': 1},
    ).requires_grad_(False).cuda()
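    # Resize the pretrained positional embeddings (bicubic, per frame) to the
    # 60x60 token grid implied by a 480x480 input with 8x8 patches, assuming the
    # checkpoint was pretrained at a different resolution.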
    teacher.predictor.encoder.pos_embed = interpolate_pos_encoding(
        teacher.predictor.encoder.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.pos_embed = interpolate_pos_encoding(
        teacher.predictor.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.image_size = image_size

    ## Start extracting segments
    start = time.time()
    if os.path.exists(args.output):
        print('Load partial results from: ', args.output)
        save_dict = torch.load(args.output)
        print('Length of existing dict: ', len(save_dict))

    for image_path in image_list:
        # Prepare input
        image_name = image_path.split('/')[-1]
        if image_name in save_dict:
            continue
        image = read_image(image_path)
        if image.shape[0] == 1:
            image = image.expand(3, -1, -1)
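        # Repeat the still image as three identical frames to form a static
        # "video" clip for the 3-frame CWM predictor.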
        x = torch.stack([image] * 3, dim=0)
        x = torch.nn.functional.interpolate(x.float(), size=image_size, mode='bicubic')[None] / 255.
        _x = x.to(torch.float16).cuda()
        targets_list = []
        # extract segments iteratively
        for n in range(args.num_iter):
            # Compute predominance map from DINO
            if n == 0:
                predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(), dims=dims, img_size=image_size)
            else:
                predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(),
                                                                         current_mask=current_mask.cuda(),
                                                                         painting=painting, dims=dims,
                                                                         img_size=image_size)
            if visualize:
                plt.imshow(predominance[0, 0].cpu())
                plt.title(f'Predominance (max: {predominance[0, 0].max()})')
                plt.show()

            # mask out segments that are already extracted
            if n > 0:
                for mask in targets_list:
                    predominance[0, 0][mask[0, 0].cuda()] = 0

            # extract segments given predominance map
            with torch.cuda.amp.autocast(enabled=True):
                targets = teacher(_x, sampling_distribution=predominance)[0]
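            # On the first iteration the thresholded mask is kept as-is; on later
            # iterations it is kept only if sufficiently novel (IoU <= 0.2 with
            # every existing segment and covering more than 1% of the pixels).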
            if n == 0:
                targets_list = [targets.cpu() >= thresh]
            else:
                ratio = targets.mean()
                mask = targets.cpu() >= thresh
                iou = 0
                match_idx = None
                for idx, existing_mask in enumerate(targets_list):
                    _iou = metric.IoU(mask[0, 0], existing_mask[0, 0])
                    if _iou > iou:
                        iou = _iou
                        match_idx = idx
                # zero out the mask if it has a large IoU with an existing segment
                if iou > 0.2 or ratio <= 0.01:
                    mask = torch.zeros_like(mask)
                # elif iou > 0.1:
                #     mask[0, 0][targets_list[match_idx][0, 0]] = 0
                targets_list.append(mask)
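            # Downsample the latest segment to the token grid; on the next
            # iteration it is fed back to get_dino_predominance so the patches it
            # covers are masked out of the affinity matrix.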
            current_mask = F.interpolate(targets, size=dims, mode='bilinear') >= thresh

        vid_name = image_path
        save_dict[image_name] = targets_list
        if visualize:
            vis_results(x, targets_list, None, vid_name.split('/')[-2] + '.png')

        # log progress and save partial results after every image
        # (the `% 1` condition below is always true)
        if (len(save_dict) + 1) % 1 == 0:
            total = len(image_list)
            num_completed = len(save_dict)
            avg_time = (time.time() - start) / num_completed
            eta = (total - num_completed) * avg_time / 60.
            print(f'{num_completed} / {total} completed, avg. time per image: {avg_time:.2f} sec, eta: {eta:.1f} mins')
            torch.save(save_dict, args.output)

    ## Save the results
    torch.save(save_dict, args.output)
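# Example invocation (paths and argument values are illustrative):
#   python generate_zero_shot_segments.py \
#       --input_pattern '/path/to/coco/val2017/*' \
#       --output ./coco_val2017_segments.pt \
#       --num_iter 4 --visualize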