# Spaces: Running on CPU Upgrade
import torch | |
import os | |
import glob | |
import time | |
from torchvision.io import read_image | |
import matplotlib.pyplot as plt | |
from scipy import ndimage | |
from PIL import Image | |
#import bbnet.trainval.validator as validator | |
import modeling_pretrain_cleaned as vmae_transformers | |
import modeling_pretrain as vmae_transformers_old | |
#import positional_vmae as pos_transformers | |
#import big_models as big_transformers | |
import bbnet.models.preprocessor as preprocessor | |
import bbnet.models.error as error_generator | |
from functools import partial | |
#import bbnet.models.teachers as teachers | |
from tqdm import tqdm | |
from torch.nn import functional as F | |
import argparse | |
import sys | |
import numpy as np | |
import json | |
import pycocotools.mask as mask_util | |
sys.path.append('/ccn2/u/honglinc/CutLER') | |
sys.path.append('/ccn2/u/honglinc/CutLER/maskcut') | |
sys.path.append('/ccn2/u/honglinc/CutLER/third_party') | |
import dino | |
import maskcut | |
# from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners, get_masked_affinity_matrix | |
from third_party.TokenCut.unsupervised_saliency_detection import utils, metric | |
from third_party.TokenCut.unsupervised_saliency_detection.object_discovery import detect_box | |
from crf import densecrf | |
#from maskcut import get_affinity_matrix, second_smallest_eigenvector, get_salient_areas, check_num_fg_corners | |
# DINO ViT-B/8 feature extractor, built once at import time.
# Hyperparameters: 'base' architecture, feature type 'k', 8x8 patches, 768-d features.
vit_arch, vit_feat, patch_size = 'base', 'k', 8
feat_dim = 768
# DINO pre-trained model weights
url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
# Frozen (eval mode, no gradients) backbone placed on the GPU.
dino_backbone = (
    dino.ViTFeat(url, feat_dim, vit_arch, vit_feat, patch_size)
    .eval()
    .requires_grad_(False)
    .cuda()
)
def get_affinity_matrix(feats, tau, eps=1e-5):
    """Build a binarized cosine-similarity affinity graph and its degree matrix.

    Args:
        feats: patch features of shape [B, C, N] (channels on dim 1, patches on dim 2).
        tau: similarity threshold; patch pairs with cosine similarity > tau get weight 1.
        eps: weight given to every other pair, so no node is isolated and every
            degree is strictly positive.

    Returns:
        (A, D): affinity matrices [B, N, N] and their diagonal degree matrices
        [B, N, N], both in the default floating dtype.
    """
    feats = F.normalize(feats, p=2, dim=1)
    similarity = feats.transpose(1, 2) @ feats  # [B, N, N] patch-wise cosine similarity
    # Build the binary weights directly in a float dtype instead of passing a bool
    # and a float tensor to torch.where, which depends on mixed-dtype promotion.
    ones = torch.ones_like(similarity, dtype=torch.get_default_dtype())
    A = torch.where(similarity > tau, ones, ones * eps)
    D = torch.diag_embed(A.sum(-1))
    return A, D
def second_smallest_eigenvector(A, D):
    """Return the negated second-smallest generalized eigenvector of (D - A) v = lambda D v.

    Args:
        A: batched affinity matrices [B, N, N].
        D: matching diagonal degree matrices [B, N, N].

    Returns:
        [B, N] tensor: the eigenvector for the second-smallest eigenvalue, sign-flipped.
    """
    laplacian = D - A
    # LOBPCG starts from a random guess; with largest=False eigenpairs come back in
    # ascending order, so column 1 is the second-smallest one.
    _, vecs = torch.lobpcg(laplacian, k=2, B=D, largest=False)
    fiedler = vecs[:, :, 1]
    return fiedler.neg()
def get_salient_areas(second_smallest_vec):
    """Split patches into two groups by thresholding the eigenvector at its mean.

    Returns a bool tensor with the same shape as the input, True where an entry is
    strictly greater than the mean taken along the last dimension.
    """
    threshold = torch.mean(second_smallest_vec, dim=-1, keepdim=True)
    return torch.gt(second_smallest_vec, threshold)
def check_num_fg_corners(bipartition, dims):
    """Count, per sample, how many of the four patch-grid corners are foreground.

    Args:
        bipartition: bool tensor [B, H * W] of per-patch foreground flags.
        dims: (H, W) patch-grid size; any sequence (list or tuple) of two ints.

    Returns:
        int32 tensor [B] with values in 0..4.
    """
    # Unpack dims instead of list-concatenating it, so tuples work as well as lists.
    grid = bipartition.reshape(bipartition.shape[0], *dims)
    top_l, top_r = grid[:, 0, 0], grid[:, 0, -1]
    bottom_l, bottom_r = grid[:, -1, 0], grid[:, -1, -1]
    return top_l.int() + top_r.int() + bottom_l.int() + bottom_r.int()
def get_dino_predominance(images, dims=[28, 28], current_mask=None, painting=None, img_size=[224, 224]):
    """Compute a MaskCut-style "predominance" map for a batch of images with DINO.

    Patch features from the module-level ``dino_backbone`` are turned into a
    thresholded affinity graph. The second-smallest generalized eigenvector is
    oriented so the foreground is high, upsampled to ``img_size`` and scaled to
    [0, 1] per image.

    Args:
        images: RGB batch [B, 3, H, W]; ImageNet mean/std are applied directly, so
            values are expected in [0, 1].
        dims: (h, w) DINO patch grid for ``img_size``; h * w must equal the number
            of feature tokens.
        current_mask: optional previously extracted mask for the masked-affinity
            path (see the note in the body).
        painting: state passed through to the masked-affinity path.
        img_size: (H, W) size DINO is run at and the output map is resized to.

    Returns:
        (init_dist, A, feats, painting):
        - init_dist: detached predominance maps [B, 1, *img_size].
        - A: affinity matrices.
        - feats: DINO features.
        - painting: a zero [h, w] canvas when current_mask is None.
    """
    dims = list(dims)
    mean = torch.tensor([0.485, 0.456, 0.406], device=images.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=images.device).view(1, 3, 1, 1)
    input_dino = (images - mean) / std
    input_dino = F.interpolate(input_dino, size=img_size, mode='bilinear')
    feats = dino_backbone(input_dino)  # [B, C, N]
    B = feats.shape[0]

    if current_mask is None:
        painting = feats.new_zeros(dims)
    else:
        # NOTE(review): get_masked_affinity_matrix is neither defined nor imported
        # anywhere visible in this file (its maskcut import is commented out), so
        # this branch raises NameError as written.
        feats, painting = get_masked_affinity_matrix(painting, feats, current_mask, ps=dims[0])

    A, D = get_affinity_matrix(feats, tau=0.15)
    second_smallest_vec = second_smallest_eigenvector(A, D)  # [B, N]
    bipartition = get_salient_areas(second_smallest_vec)

    # Reverse the partition when 1) the eigenvector's peak (largest |value|) falls
    # outside the foreground, or 2) two or more grid corners are foreground
    # (object-centric bias).
    seed = second_smallest_vec.abs().argmax(dim=-1)  # [B]
    batch_idx = torch.arange(B, device=seed.device)
    reverse = ~bipartition[batch_idx, seed]
    nc = check_num_fg_corners(bipartition, dims)
    reverse[nc >= 2] = True
    second_smallest_vec[reverse] = 1 - second_smallest_vec[reverse]
    second_smallest_vec = second_smallest_vec.to(images.device).contiguous()

    pred_map = F.interpolate(second_smallest_vec.reshape(B, 1, dims[0], dims[1]),
                             size=img_size, mode='bicubic')
    # Scale each image to [0, 1] independently. Batch-wide min/max would make one
    # image's map depend on the rest of the batch. The clamp avoids 0/0 on a
    # constant map.
    flat = pred_map.flatten(1)
    lo = flat.min(dim=1).values.view(B, 1, 1, 1)
    hi = flat.max(dim=1).values.view(B, 1, 1, 1)
    pred_map = (pred_map - lo) / (hi - lo).clamp(min=torch.finfo(pred_map.dtype).eps)

    init_dist = pred_map.detach().contiguous()
    return init_dist, A, feats, painting
def interpolate_pos_encoding(pos_embed, n_frames, h, w):
    """Resize per-frame square grids of positional embeddings to an h x w grid.

    Args:
        pos_embed: embeddings [1, n_frames * s * s, C], i.e. one s x s grid per frame.
        n_frames: number of frames (temporal positions) in ``pos_embed``.
        h, w: target patch-grid height and width.

    Returns:
        ``pos_embed`` itself if it already holds n_frames * h * w positions.
        Otherwise a new [1, n_frames * h * w, C] tensor resized per frame with
        bicubic interpolation.

    Raises:
        ValueError: if the number of positions per frame is not a perfect square.
    """
    N = pos_embed.shape[1]
    if N == h * w * n_frames:
        return pos_embed
    # Round-and-verify instead of truncating a float sqrt, and fail with a clear
    # message for non-square grids.
    old_size = round((N / n_frames) ** 0.5)
    if old_size * old_size * n_frames != N:
        raise ValueError(
            f'cannot interpolate {N} positions: expected {n_frames} square per-frame grids')
    # reshape (not view) also accepts non-contiguous embeddings.
    grid = pos_embed.reshape(1, n_frames, old_size, old_size, -1).flatten(0, 1).permute(0, 3, 1, 2)  # [T, C, s, s]
    grid = F.interpolate(grid, size=(h, w), mode='bicubic')
    return grid.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
def vis_results(x, targets_dict, predominance, annotation, name):
    """Show, per batch element, the image, its predominance map and each segment.

    Args:
        x: clip batch [B, T, C, H, W]; frame 0 of each clip is displayed.
        targets_dict: sequence of bool masks, one per extracted segment, each
            indexable as ``mask[b, 0]`` -> [H, W].
        predominance: predominance maps indexable as ``predominance[b, 0]`` -> [H, W].
        annotation: unused.
        name: unused; kept for call compatibility (callers pass a .png name).
    """
    B = x.shape[0]
    n_cols = 2 + len(targets_dict)
    # squeeze=False keeps axs 2-D even for a single image or a single column.
    fig, axs = plt.subplots(B, n_cols, figsize=(3 * n_cols, 2 * B), squeeze=False)
    for b in range(B):
        img = x[b, 0].permute(1, 2, 0).cpu()
        axs[b, 0].imshow(img)
        axs[b, 0].set_title('Image')
        axs[b, 1].imshow(predominance[b, 0].cpu())
        axs[b, 1].set_title('Predominance')
        # Segments start after the image and predominance columns.
        for i, mask in enumerate(targets_dict):
            m = mask[b, 0].cpu()[..., None]
            # Keep segment pixels, paint everything else white.
            axs[b, 2 + i].imshow(m * img + ~m * torch.ones_like(img))
            axs[b, 2 + i].set_title(f'Segment {i}', fontsize=10)
    for ax in axs.flat:
        ax.set_axis_off()
    plt.show()
    plt.close(fig)
def _expand_input_paths(input_pattern):
    """Return the sorted, de-duplicated image paths selected by ``--input_pattern``.

    ``input_pattern`` is either the default glob string or the list produced by
    ``nargs='+'``. Every entry is glob-expanded, so quoted wildcards work as well as
    shell-expanded file lists. Existing files that glob cannot match literally
    (e.g. names containing '[') are kept verbatim.
    """
    patterns = [input_pattern] if isinstance(input_pattern, str) else input_pattern
    paths = set()
    for pattern in patterns:
        matches = glob.glob(pattern)
        if not matches and os.path.isfile(pattern):
            matches = [pattern]
        if not matches:
            print(f'Warning: no file matches {pattern!r}')
        paths.update(matches)
    return sorted(paths)


def _load_clip(image_path, image_size):
    """Load one image as a static 3-frame clip [1, 3, 3, H, W] with values in [0, 1]."""
    image = read_image(image_path)  # [C, H, W]
    if image.shape[0] in (1, 2):
        # grayscale (optionally with alpha): replicate the first channel to RGB
        image = image[:1].expand(3, -1, -1)
    image = image[:3]  # drop the alpha channel of RGBA images
    clip = torch.stack([image] * 3, dim=0)  # [3 frames, 3, H, W]
    clip = F.interpolate(clip.float(), size=image_size, mode='bicubic')
    return clip[None] / 255.


if __name__ == "__main__":
    # The module-level import of the teachers module is commented out, but this
    # script needs it.
    import bbnet.models.teachers as teachers

    parser = argparse.ArgumentParser('Generate zero-shot segments from CWM model', add_help=False)
    parser.add_argument('--input_pattern', default='/ccn2/u/honglinc/datasets/coco/images/val2017/*', nargs='+', type=str, help='Pattern for input images')
    parser.add_argument('--output', default='./output.pt', type=str, help='output path for saving the results')
    parser.add_argument('--num_iter', default=1, type=int, help='number of iterations')
    parser.add_argument('--visualize', action='store_true', help='Visualize the results')
    args = parser.parse_args()

    ## Prepare for the extraction
    image_list = _expand_input_paths(args.input_pattern)
    thresh = 0.5
    visualize = args.visualize
    save_dict = {}
    image_size = [480, 480]
    patch_size = 8
    dims = [int(s / patch_size) for s in image_size]  # patch grid of image_size
    batch_size = 10

    ## Load pretrained model
    default_model_dir = '/ccn2/u/honglinc/cwm_checkpoints/'
    model_func = vmae_transformers.vitb_8x8patch_3frames
    ckpt_path = 'ablation_3frame_no_clumping_mr0.90_extra_data_ep400'
    label = '3 frame 8x8'
    teacher_func = teachers.iteration_segment_teacher_with_filter
    teacher = teacher_func(
        model_func=model_func,
        model_path=teachers.get_load_path(os.path.join(default_model_dir, ckpt_path), model_checkpoint=-1),
        visualization_mode=visualize,
        initial_sampling_distribution_kwargs={'num_samples': 20, 'num_active_patches': 1, 'num_passive_patches': 1},
    ).requires_grad_(False).cuda()
    # Resize the 3-frame positional embeddings to the patch grid of image_size.
    teacher.predictor.encoder.pos_embed = interpolate_pos_encoding(
        teacher.predictor.encoder.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.pos_embed = interpolate_pos_encoding(
        teacher.predictor.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.image_size = image_size

    ## Start extracting segments
    start = time.time()
    num_completed = 0
    batch, batch_paths = [], []
    for idx, image_path in enumerate(image_list):
        batch.append(_load_clip(image_path, image_size))
        batch_paths.append(image_path)
        # Run when the batch is full or at the last image, so no image is skipped.
        if len(batch) < batch_size and idx < len(image_list) - 1:
            continue
        x = torch.cat(batch, dim=0)  # [B, 3 frames, 3, H, W]
        paths_in_batch = batch_paths
        batch, batch_paths = [], []

        _x = x.to(torch.float16).cuda()
        rgb = x[:, 0].cuda()  # frame 0 of every clip: RGB batch [B, 3, H, W] for DINO
        targets_list = []
        # extract segments iteratively
        for n in range(args.num_iter):
            # Compute predominance map from dino
            if n == 0:
                predominance, _, feats, painting = get_dino_predominance(rgb, dims=dims, img_size=image_size)
            else:
                # NOTE: refinement iterations are not implemented yet
                # (get_masked_affinity_matrix is missing). The statements after this
                # raise sketch the intended algorithm.
                raise ValueError('Not implemented')
                predominance, _, feats, painting = get_dino_predominance(
                    rgb, current_mask=current_mask.cuda(), painting=painting,
                    dims=dims, img_size=image_size)
            # mask out segments that are already extracted
            # NOTE(review): only batch element 0 is masked here.
            if n > 0:
                for mask in targets_list:
                    predominance[0, 0][mask[0, 0].cuda()] = 0
            # extract segments given predominance map
            with torch.cuda.amp.autocast(enabled=True):
                targets = teacher(_x, sampling_distribution=predominance)[0]
            if n == 0:
                targets_list = [targets.cpu() >= thresh]
            else:
                ratio = targets.mean()
                mask = targets.cpu() >= thresh
                iou = 0
                for existing_mask in targets_list:
                    _iou = metric.IoU(mask[0, 0], existing_mask[0, 0])
                    if _iou > iou:
                        iou = _iou
                # remove segments that overlap an existing one or are nearly empty
                if iou > 0.2 or ratio <= 0.01:
                    mask = torch.zeros_like(mask)
                targets_list.append(mask)
            current_mask = F.interpolate(targets, size=dims, mode='bilinear') >= thresh

        # Save one entry per image, keeping a leading batch dim of 1.
        # NOTE(review): assumes dim 0 of every mask is the batch dimension
        # (vis_results indexes mask[b, 0]); confirm against the teacher's output.
        for b, path in enumerate(paths_in_batch):
            save_dict[path.split('/')[-1]] = [mask[b:b + 1] for mask in targets_list]
        num_completed += len(paths_in_batch)

        if visualize:
            dir_name = os.path.basename(os.path.dirname(paths_in_batch[0]))
            vis_results(x, targets_list, predominance, None, dir_name + '.png')

        total = len(image_list)
        avg_time = (time.time() - start) / num_completed
        eta = (total - num_completed) * avg_time / 60.
        print(f'{num_completed} / {total} completed, avg. time per image: {avg_time:.2f} sec, eta: {eta:.1f} mins')

    ## Save the results
    torch.save(save_dict, args.output)