# Source: Hugging Face Space "rahulvenkk" — app.py, commit 6dfcb0f (6.48 kB)
from einops import rearrange
import torch
import numpy as np
from torchvision import transforms
def unpatchify(labels, norm=True):
# Define the input tensor
B = labels.shape[0] # batch size
N_patches = int(np.sqrt(labels.shape[1])) # number of patches along each dimension
patch_size = int(np.sqrt(labels.shape[2] / 3)) # patch size along each dimension
channels = 3 # number of channels
rec_imgs = rearrange(labels, 'b n (p c) -> b n p c', c=3)
# Notice: To visualize the reconstruction video, we add the predict and the original mean and var of each patch.
rec_imgs = rearrange(rec_imgs,
'b (t h w) (p0 p1 p2) c -> b c (t p0) (h p1) (w p2)',
p0=1,
p1=patch_size,
p2=patch_size,
h=N_patches,
w=N_patches)
if norm:
MEAN = torch.from_numpy(np.array((0.485, 0.456, 0.406))[None, :, None, None, None]).cuda().half()
STD = torch.from_numpy(np.array((0.229, 0.224, 0.225))[None, :, None, None, None]).cuda().half()
rec_imgs = (rec_imgs - MEAN) / STD
return rec_imgs
def upsample_masks(masks, size, thresh=0.5):
    """Resize a batch of patch-level masks to a target spatial size.

    Args:
        masks: tensor of shape (..., h, w); any number of leading dims.
        size: (H, W) target spatial size.
        thresh: binarization threshold used only on the interpolating
            (non-integer-scale) upsample path.

    Returns:
        Tensor of shape (..., H, W), same dtype as the input.
    """
    shape = masks.shape
    dtype = masks.dtype
    h, w = shape[-2:]
    H, W = size
    if (H == h) and (W == w):
        return masks
    elif (H < h) and (W < w):
        # Downsample by strided subsampling.
        s = (h // H, w // W)
        return masks[..., ::s[0], ::s[1]]
    # Upsample: replicate each mask cell into an (H // h, W // w) block.
    masks = masks.unsqueeze(-2).unsqueeze(-1)
    masks = masks.repeat(*([1] * (len(shape) - 2)), 1, H // h, 1, W // w)
    if ((H % h) == 0) and ((W % w) == 0):
        masks = masks.view(*shape[:-2], H, W)
    else:
        # Non-integer scale: bilinearly resize the block-replicated mask,
        # then re-binarize. F.interpolate(bilinear, align_corners=False)
        # matches the previous torchvision Resize (antialias off) while
        # dropping the torchvision dependency.
        _H = int(np.prod(masks.shape[-4:-2]))
        _W = int(np.prod(masks.shape[-2:]))
        resized = torch.nn.functional.interpolate(
            masks.reshape(-1, 1, _H, _W).float(),
            size=size, mode='bilinear', align_corners=False)
        masks = resized > thresh
        # Bug fixes: restore ALL leading dims (was shape[:2], which broke for
        # rank != 4 inputs) and the ORIGINAL dtype (was masks.dtype, i.e. the
        # bool dtype of the freshly thresholded tensor).
        masks = masks.view(*shape[:-2], H, W).to(dtype)
    return masks
def get_keypoints_batch(model, x,
                        n_samples,
                        n_rounds,
                        frac=0.25,
                        mask=None,
                        pool='avg',
                        ):
    """Greedy per-batch keypoint discovery via masked reconstruction error.

    Each round ranks the still-masked target-frame patches by current
    reconstruction error, tentatively reveals the top ``n_samples``
    candidates one at a time, and permanently unmasks the candidate whose
    reveal yields the lowest remaining masked-region error.

    Args:
        model: masked-prediction model; called as
            ``model(x, mask, forward_full=True)`` and expected to expose
            ``patch_size`` (indexable) and ``num_frames``.
        x: image-pair tensor; spatial size is assumed to be 224x224 (see
            IMAGE_SIZE below) — TODO confirm against callers.
        n_samples: number of candidate patches evaluated per round.
        n_rounds: number of rounds, i.e. total patches unmasked (keypoints).
        frac: intended random-vs-error-based sampling ratio.
            NOTE(review): currently unused — the random-sampling branch was
            never implemented (``rng`` below is also unused).
        mask: optional initial boolean mask of shape (B, num_frames * P^2);
            if None, the target frame starts fully masked (True) and all
            earlier frames visible (False).
        pool: 'avg' or 'max' — how pixel error is pooled to patch level.
            NOTE(review): any other value leaves ``pool_op`` undefined and
            raises NameError later.

    Returns:
        (mask, choices, err_array, feat, keypoint_recon):
            mask: final boolean mask after n_rounds reveals.
            choices: (B, n_rounds, 2) tensor of (x, y) patch-grid
                coordinates of the chosen patches.
            err_array: list of per-round (B, n_samples) candidate errors.
            feat: model features under the final mask.
            keypoint_recon: uint8 HxWx3 un-normalized reconstruction of the
                first batch element.
    """
    # .half()
    B = x.shape[0]
    IMAGE_SIZE = [224, 224]
    predictor = model
    patch_size = predictor.patch_size[-1]
    num_frames = predictor.num_frames
    patch_num = IMAGE_SIZE[0] // patch_size
    # Pooling op that reduces the per-pixel error map to per-patch error.
    if pool == 'avg':
        pool_op = torch.nn.AvgPool2d(patch_size, stride=patch_size)
    elif pool == 'max':
        pool_op = torch.nn.MaxPool2d(patch_size, stride=patch_size)
    # Fixed-seed RNG. NOTE(review): never used — selection is fully
    # deterministic (see ``frac`` note in the docstring).
    rng = np.random.RandomState(seed=0)
    n_patches = patch_num * patch_num
    # Flat mask layout: num_frames blocks of patch_num^2 entries each.
    # True = masked; the first (num_frames - 1) frames start visible.
    mshape = num_frames * patch_num * patch_num
    mshape_masked = (num_frames - 1) * patch_num * patch_num
    if mask is None:
        mask = torch.ones([B, mshape], dtype=torch.bool)
        mask[:, :mshape_masked] = False
    err_array = []
    choices = []
    # flows = []
    for round_num in range(n_rounds):
        # Reconstruction under the current mask.
        out = unpatchify(predictor(x, mask, forward_full=True))
        # NOTE(review): overwritten unconditionally after the loop; this
        # assignment has no lasting effect.
        keypoint_recon = out.clone()
        # flow = teacher.predict_flow(out)
        # flows.append(flow)
        # Per-pixel absolute error against the target (last) frame,
        # averaged over channels.
        err_mat = (out[:, :, 0] - x[:, :, -1]).abs().mean(1)
        # Pool pixel error down to one value per patch.
        pooled_err = pool_op(err_mat[:, None])
        # Flatten to (B, n_patches).
        flat_pooled_error = pooled_err.flatten(1, 3)
        # Zero the error at already-visible patches so they cannot be
        # re-selected as candidates.
        flat_pooled_error[mask[:, -n_patches:] == False] = 0
        # Ascending sort: highest-error patches sit at the end.
        err_sort = torch.argsort(flat_pooled_error, -1)
        new_mask = mask.clone().detach()
        errors = []
        tries = []
        err_choices = 0
        # Evaluate the n_samples highest-error candidate patches.
        for sample_num in range(n_samples):
            err_choices += 1
            # err_choices-th highest-error patch, offset into the target
            # frame's slice of the flat mask.
            new_try = (num_frames - 1) * n_patches + err_sort[:, -1 * err_choices]
            tries.append(new_try)
            # Tentatively reveal the candidate for every batch element.
            for k in range(B):
                new_mask[k, new_try[k]] = False
            # Pixel-resolution view of the target frame's mask, used to
            # restrict the error to still-masked regions.
            reshaped_new_mask = upsample_masks(
                new_mask.view(B, num_frames, IMAGE_SIZE[1] // patch_size, IMAGE_SIZE[1] // patch_size)[:, (num_frames - 1):],
                IMAGE_SIZE)[:, 0]
            # Re-run prediction with the candidate revealed and score the
            # remaining masked-region error.
            out = unpatchify(predictor(x, new_mask, forward_full=True))
            abs_error = (out[:, :, 0] - x[:, :, -1]).abs().sum(1).cpu()
            masked_abs_error = abs_error * reshaped_new_mask
            error = masked_abs_error.flatten(1, 2).sum(-1)
            errors.append(error)
            # Undo the tentative reveal before trying the next candidate.
            for k in range(B):
                new_mask[k, new_try[k]] = True
        errors = torch.stack(errors, 1)
        tries = torch.stack(tries, 1)
        # Keep the candidate whose reveal minimized the masked error.
        best_ind = torch.argmin(errors, dim=-1)
        best = torch.tensor([tries[k, best_ind[k]] for k in range(B)])
        choices.append(best)
        err_array.append(errors)
        # Permanently reveal the winning patch for each batch element.
        for k in range(B):
            mask[k, best[k]] = False
    # Features under the final mask.
    feat = predictor(x, mask, forward_full=True, return_features=True)
    feat = feat  # [:, :784*2]
    choices = torch.stack(choices, 1)
    # Convert flat target-frame indices to (x, y) patch-grid coordinates:
    # the modulo strips the (num_frames - 1) * n_patches frame offset.
    choices = choices % mshape_masked
    choices_x = choices % (patch_num)
    choices_y = choices // (patch_num)
    choices = torch.stack([choices_x, choices_y], 2)
    # Final un-normalized reconstruction for visualization (first element).
    out = unpatchify(predictor(x, mask, forward_full=True), norm=False)
    keypoint_recon = out[0, :, 0].permute(1, 2, 0).detach().cpu().numpy() * 255
    return mask, choices, err_array, feat, keypoint_recon.astype('uint8')