from einops import rearrange
import torch
import numpy as np
from torchvision import transforms


def unpatchify(labels, norm=True):
    """Fold a sequence of flattened RGB patches back into an image tensor.

    Args:
        labels: tensor indexed as ``(batch, n_patches, patch_pixels * 3)``.
            The per-side patch count is taken as ``sqrt(labels.shape[1])``,
            so this presumably expects a single predicted frame with a
            square patch grid -- TODO confirm against the model output.
        norm: when true, subtract the ImageNet channel means and divide by
            the ImageNet channel standard deviations.

    Returns:
        Tensor laid out as ``(batch, 3, t, H, W)`` where ``t`` is inferred by
        ``rearrange`` from the number of patches.
    """
    grid = int(np.sqrt(labels.shape[1]))  # patches along each spatial side
    patch_size = int(np.sqrt(labels.shape[2] / 3))  # pixels along each patch side

    # Split the flattened patch vector into (pixels, channels) ...
    rec_imgs = rearrange(labels, 'b n (p c) -> b n p c', c=3)
    # ... then tile the patches back onto the (t, h, w) grid.
    rec_imgs = rearrange(
        rec_imgs,
        'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)',
        p0=1,
        p1=patch_size,
        p2=patch_size,
        h=grid,
        w=grid,
    )

    if norm:
        # Build the statistics on the input's own device/dtype instead of
        # forcing CUDA half precision, so CPU and float32 inputs work too.
        stat_dtype = rec_imgs.dtype if rec_imgs.is_floating_point() else torch.float32
        mean = torch.tensor(
            (0.485, 0.456, 0.406), dtype=stat_dtype, device=rec_imgs.device
        ).view(1, 3, 1, 1, 1)
        std = torch.tensor(
            (0.229, 0.224, 0.225), dtype=stat_dtype, device=rec_imgs.device
        ).view(1, 3, 1, 1, 1)
        rec_imgs = (rec_imgs - mean) / std

    return rec_imgs


def upsample_masks(masks, size, thresh=0.5):
    """Resize the trailing two (spatial) dimensions of ``masks`` to ``size``.

    Args:
        masks: tensor whose last two dimensions are ``(h, w)``.
        size: target ``(H, W)``.
        thresh: cut-off applied after ``transforms.Resize`` when the target
            size is not an integer multiple of the source size.

    Returns:
        * ``masks`` itself when the size already matches;
        * a strided subsample when both target sides are smaller;
        * otherwise each cell repeated ``H // h`` by ``W // w`` times, followed
          (only if the sizes do not divide evenly) by a resize + threshold.
        The result keeps the dtype of the input.

    Note:
        Mixed requests (one side smaller, the other larger) are not supported:
        the repeat factor for the shrinking side becomes zero.
    """
    shape = masks.shape
    dtype = masks.dtype
    h, w = shape[-2:]
    H, W = size

    if H == h and W == w:
        return masks
    if H < h and W < w:
        return masks[..., ::h // H, ::w // W]

    lead = shape[:-2]
    # (..., h, w) -> (..., h, 1, w, 1) -> (..., h, H // h, w, W // w)
    tiled = masks.unsqueeze(-2).unsqueeze(-1)
    tiled = tiled.repeat(*([1] * len(lead)), 1, H // h, 1, W // w)

    if H % h == 0 and W % w == 0:
        return tiled.view(*lead, H, W)

    # Not an exact multiple: collapse the repeated axes and let torchvision
    # interpolate the remainder, then binarise with ``thresh``.
    tiled_h = h * (H // h)
    tiled_w = w * (W // w)
    resized = transforms.Resize(size)(tiled.view(-1, 1, tiled_h, tiled_w)) > thresh
    # Cast back to the *input* dtype (the comparison above yields bool).
    return resized.view(*lead, H, W).to(dtype)


def get_keypoints_batch(model, x, n_samples, n_rounds, frac=0.25, mask=None, pool='avg', ):
    """Greedily pick patches of the last frame to reveal ("keypoints").

    In every round the current reconstruction error is pooled per patch, the
    ``n_samples`` still-masked patches with the highest pooled error are each
    revealed in turn, and the candidate that gives the lowest remaining
    error over the masked region is kept.

    Args:
        model: predictor exposing ``patch_size`` (last entry is used),
            ``num_frames`` and being callable as
            ``model(x, mask, forward_full=True[, return_features=True])``.
        x: image pair tensor; indexed as ``x[:, :, -1]`` for the target
            frame, so presumably ``(B, C, T, H, W)`` at 224x224 -- TODO
            confirm against the caller.
        n_samples: number of candidate patches evaluated in each round.
        n_rounds: number of rounds; one patch is revealed per round.
        frac: kept for interface compatibility; not used by this
            implementation.
        mask: optional initial boolean mask of shape
            ``(B, num_frames * n_patches)`` with ``True`` meaning masked.
            It is updated in place. When omitted, all frames but the last
            are visible and the last frame is fully masked.
        pool: ``'avg'`` or ``'max'`` pooling of the error map per patch.

    Returns:
        ``(mask, choices, err_array, feat, keypoint_recon)`` where
        ``choices`` holds the ``(x, y)`` patch coordinates chosen per round
        (shape ``(B, n_rounds, 2)``), ``err_array`` is a list with one
        ``(B, n_samples)`` error tensor per round, ``feat`` is the model's
        feature output for the final mask and ``keypoint_recon`` is the
        un-normalised reconstruction of the first batch element as a uint8
        array.

    Raises:
        ValueError: if ``pool`` is neither ``'avg'`` nor ``'max'``.
    """
    if pool == 'avg':
        pool_cls = torch.nn.AvgPool2d
    elif pool == 'max':
        pool_cls = torch.nn.MaxPool2d
    else:
        raise ValueError("pool must be 'avg' or 'max', got %r" % (pool,))

    B = x.shape[0]
    IMAGE_SIZE = [224, 224]
    predictor = model
    patch_size = predictor.patch_size[-1]
    num_frames = predictor.num_frames
    patch_num = IMAGE_SIZE[0] // patch_size

    # Pools a per-pixel error map down to one value per patch.
    pool_op = pool_cls(patch_size, stride=patch_size)

    n_patches = patch_num * patch_num
    mshape = num_frames * n_patches
    # Offset of the last frame's patches inside the flattened mask.
    mshape_masked = (num_frames - 1) * n_patches

    if mask is None:
        # Start with the last frame fully masked and every other frame visible.
        mask = torch.ones([B, mshape], dtype=torch.bool)
        mask[:, :mshape_masked] = False

    batch_idx = torch.arange(B, device=mask.device)

    err_array = []
    choices = []
    for _ in range(n_rounds):
        # Prediction for the current state of the mask.
        out = unpatchify(predictor(x, mask, forward_full=True))

        # Per-pixel error, pooled to one value per patch and flattened.
        err_mat = (out[:, :, 0] - x[:, :, -1]).abs().mean(1)
        flat_pooled_error = pool_op(err_mat[:, None]).flatten(1, 3)
        # Already revealed patches must not compete for selection.
        flat_pooled_error[torch.logical_not(mask[:, -n_patches:])] = 0
        # Ascending sort: the highest-error patches sit at the end.
        err_sort = torch.argsort(flat_pooled_error, -1)

        errors = []
        tries = []
        # Evaluate the candidates from highest pooled error downwards.
        for rank in range(1, n_samples + 1):
            new_try = mshape_masked + err_sort[:, -rank]
            tries.append(new_try)

            # Fresh copy per candidate so that trying a patch can never alter
            # the state of positions that were already revealed in ``mask``.
            new_mask = mask.clone().detach()
            new_mask[batch_idx, new_try.to(mask.device)] = False

            reshaped_new_mask = upsample_masks(
                new_mask.view(B, num_frames, patch_num, patch_num)[:, (num_frames - 1):],
                IMAGE_SIZE,
            )[:, 0]

            out = unpatchify(predictor(x, new_mask, forward_full=True))
            abs_error = (out[:, :, 0] - x[:, :, -1]).abs().sum(1).cpu()
            # Only the error over the still-masked region counts.
            masked_abs_error = abs_error * reshaped_new_mask
            errors.append(masked_abs_error.flatten(1, 2).sum(-1))

        errors = torch.stack(errors, 1)
        tries = torch.stack(tries, 1).cpu()

        # Keep, per batch element, the candidate with the lowest remaining error.
        best_ind = torch.argmin(errors, dim=-1)
        best = tries[torch.arange(B), best_ind]
        choices.append(best)
        err_array.append(errors)
        mask[batch_idx, best.to(mask.device)] = False

    feat = predictor(x, mask, forward_full=True, return_features=True)

    # Convert flat mask indices into (x, y) patch coordinates of the last frame.
    choices = torch.stack(choices, 1) % n_patches
    choices_x = choices % patch_num
    choices_y = choices // patch_num
    choices = torch.stack([choices_x, choices_y], 2)

    out = unpatchify(predictor(x, mask, forward_full=True), norm=False)
    keypoint_recon = out[0, :, 0].permute(1, 2, 0).detach().cpu().numpy() * 255
    return mask, choices, err_array, feat, keypoint_recon.astype('uint8')