import kornia
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
import cwm.eval.Flow.masking_flow as masking
def boltzmann(x, beta=1, eps=1e-9):
    """Exponentiate ``x`` and rescale by its peak over the last two dims.

    Args:
        x: input tensor; the peak is taken over its last two dimensions.
        beta: multiplier applied to ``x`` before ``torch.exp``. ``None``
            disables the transform and returns ``x`` unchanged.
        eps: lower bound applied to the peak before dividing.

    Returns:
        ``exp(x * beta)`` divided by its (clamped) maximum over dims -1/-2,
        or ``x`` itself when ``beta`` is None.
    """
    if beta is not None:
        weights = torch.exp(x * beta)
        peak = weights.amax((-1, -2), keepdim=True)
        return weights / peak.clamp(min=eps)
    return x
# Per-channel (3-value) mean and standard deviation read by
# imagenet_normalize() and imagenet_unnormalize() below.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def imagenet_normalize(x, temporal_dim=1):
    """Subtract IMAGENET_DEFAULT_MEAN and divide by IMAGENET_DEFAULT_STD.

    The statistics are broadcast as a 5-D tensor whose channel axis is
    axis 2 by default, or axis 1 when ``temporal_dim == 2``.

    Args:
        x: tensor to normalize; the statistics are cast to its device/dtype.
        temporal_dim: 2 if the time axis is axis 2 (channels on axis 1);
            any other value keeps channels on axis 2.

    Returns:
        ``(x - mean) / std``.
    """
    channel_axis = 1 if temporal_dim == 2 else 2
    stats_shape = [1, 1, 1, 1, 1]
    stats_shape[channel_axis] = -1
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(x).view(stats_shape)
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(x).view(stats_shape)
    return (x - mean) / std
def imagenet_unnormalize(x, temporal_dim=2):
    """Invert imagenet-style normalization: ``x * std + mean``.

    The statistics are broadcast as a 5-D tensor whose channel axis is
    axis 1 when ``temporal_dim == 2`` (the default here) and axis 2
    otherwise.

    Args:
        x: normalized tensor; the statistics are cast to its device/dtype.
        temporal_dim: 2 if the time axis is axis 2 (channels on axis 1).

    Returns:
        The un-normalized tensor.
    """
    def _broadcastable(values):
        stat = torch.as_tensor(values).to(x.device)[None, None, :, None, None].to(x)
        return stat.transpose(1, 2) if temporal_dim == 2 else stat

    return x * _broadcastable(IMAGENET_DEFAULT_STD) + _broadcastable(IMAGENET_DEFAULT_MEAN)
def coordinate_ims(batch_size, seq_length, imsize, normalize=True, dtype_out=torch.float32):
    """Build per-pixel (row, column) coordinate images.

    Args:
        batch_size: size ``B`` of the leading dimension.
        seq_length: number of frames ``T``. ``0`` requests a single static
            frame and drops the time dimension from the result.
        imsize: ``(H, W)`` pair.
        normalize: if True, coordinates run linearly from -1 (first
            row/column) to 1 (last row/column); otherwise they are the raw
            indices ``0..H-1`` / ``0..W-1``.
        dtype_out: dtype of the returned tensor.

    Returns:
        Tensor of shape ``[B, T, H, W, 2]`` (``[B, H, W, 2]`` when
        ``seq_length == 0``). The last axis holds (row, column).
    """
    static = seq_length == 0
    T = 1 if static else seq_length
    B = batch_size
    H, W = imsize

    ones = torch.ones([B, H, W, 1], dtype=dtype_out)
    rows = torch.arange(H).to(ones)
    cols = torch.arange(W).to(ones)
    if normalize:
        # The normalized and raw coordinates are mutually exclusive; the raw
        # ones must not overwrite the normalized ones.
        # max(., 1) keeps a length-1 axis from producing 0/0 = NaN.
        rows = 2.0 * (rows / max(H - 1, 1) - 0.5)
        cols = 2.0 * (cols / max(W - 1, 1) - 0.5)

    h = rows.view(1, H, 1, 1) * ones
    w = cols.view(1, 1, W, 1) * ones
    h = torch.stack([h] * T, 1)
    w = torch.stack([w] * T, 1)
    hw_ims = torch.cat([h, w], -1)
    if static:
        hw_ims = hw_ims[:, 0]
    return hw_ims
def get_distribution_centroid(dist, eps=1e-9, normalize=False):
    """Expected (row, column) coordinate under each spatial distribution.

    Args:
        dist: tensor of shape ``[B, T, 1, H, W]`` holding non-normalized
            weights; each ``H x W`` map is divided by its own sum.
        eps: lower bound on that sum, guarding against division by zero.
        normalize: forwarded to ``coordinate_ims`` to choose between
            normalized and raw index coordinates.

    Returns:
        Tensor of shape ``[B, T, 2]`` with the weighted mean coordinate.
    """
    B, T, C, H, W = dist.shape
    assert C == 1
    total = dist.sum((-2, -1), keepdim=True).clamp(min=eps)
    probs = dist / total

    # [B, T, H, W, 2] -> [B, T, 2, H, W] so the grid broadcasts against probs
    coords = coordinate_ims(B, T, [H, W], normalize=normalize).to(dist.device)
    coords = coords.permute(0, 1, 4, 2, 3)
    return (coords * probs).sum((-2, -1))
class FlowToRgb(object):
    """Render a 2-channel flow field as an RGB image via HSV.

    Flow direction becomes hue, ``speed / max_speed`` becomes value, and
    saturation is fixed at 1.
    """

    def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
        """
        Args:
            max_speed: divisor applied to the flow magnitude.
            from_image_coordinates: channels are ``(h, w)``; they are
                converted to ``x = w``, ``y = -h``.
            from_sampling_grid: channels are ``(x, y)`` and ``y`` is
                negated. Takes precedence over ``from_image_coordinates``.
        """
        self.max_speed = max_speed
        self.from_image_coordinates = from_image_coordinates
        self.from_sampling_grid = from_sampling_grid

    def __call__(self, flow):
        """Convert ``flow`` (size 2 on dim -3) to RGB with 3 channels on dim -3."""
        assert flow.size(-3) == 2, flow.shape
        first, second = torch.split(flow, [1, 1], dim=-3)
        if self.from_sampling_grid:
            flow_x, flow_y = first, -second
        elif not self.from_image_coordinates:
            flow_x, flow_y = first, second
        else:
            # Image coordinates only apply when neither of the conventions
            # above was selected; they must not override them.
            flow_x, flow_y = second, -first

        angle = torch.atan2(flow_y, flow_x)  # in radians from -pi to pi
        speed = torch.sqrt(flow_x**2 + flow_y**2) / self.max_speed

        hue = torch.fmod(angle, torch.tensor(2 * np.pi))
        sat = torch.ones_like(hue)
        val = speed

        hsv = torch.cat([hue, sat, val], -3)
        rgb = kornia.color.hsv_to_rgb(hsv)
        return rgb
class Patchify(nn.Module):
"""Convert a set of images or a movie into patch vectors"""
# NOTE(review): the constructor below is incomplete in this view. The
# signature stops after `patch_size` although `temporal_dim` and
# `squeeze_channel_dim` are read in the body, and no visible line calls
# set_patch_size() or the parent initializer. Restore the missing lines
# from the original file before relying on this block.
def __init__(self,
patch_size=(16, 16),
# Axis of the input that holds time; only 1 or 2 is accepted.
self.temporal_dim = temporal_dim
assert self.temporal_dim in [1, 2], self.temporal_dim
# When truthy, forward() flattens each patch to a single vector.
self._squeeze_channel_dim = squeeze_channel_dim
@property
def num_patches(self):
    """Total number of patches for the last checked input shape.

    Exposed as a property because the rest of the class reads it as an
    attribute (``self.num_patches - _N``, ``_N < self.num_patches``).

    Returns:
        ``(T // pt) * (H // ph) * (W // pw)``, or None while any of
        T / H / W is still unset.
    """
    if (self.T is None) or (self.H is None) or (self.W is None):
        return None
    return (self.T // self.pt) * (self.H // self.ph) * (self.W // self.pw)
def set_patch_size(self, patch_size):
    """Set the patch size and reset all cached shape information.

    Args:
        patch_size: ``(ph, pw)`` for spatial patches (``pt`` becomes 1) or
            ``(pt, ph, pw)`` for spatio-temporal patches.

    Raises:
        ValueError: if ``patch_size`` has neither 2 nor 3 entries.
    """
    self.patch_size = patch_size
    if len(self.patch_size) == 2:
        self.ph, self.pw = self.patch_size
        self.pt = 1
        self._patches_are_3d = False
    elif len(self.patch_size) == 3:
        self.pt, self.ph, self.pw = self.patch_size
        self._patches_are_3d = True
    else:
        # Wrap in a 1-tuple: formatting a multi-element tuple with a single
        # %s would itself raise TypeError and mask the real error.
        raise ValueError("patch_size must be a 2- or 3-tuple, but is %s" % (self.patch_size,))

    # Shape-dependent state is invalid after a patch-size change.
    self.shape_inp = self.rank_inp = self.H = self.W = self.T = None
    self.D = self.C = self.E = self.embed_dim = None
def _check_shape(self, x):
    """Record the shape of ``x`` and verify it is divisible into patches.

    Sets ``shape_inp``, ``rank_inp``, ``H``, ``W`` and ``T``. ``T`` is read
    from axis ``temporal_dim`` for rank-5 input and is 1 otherwise.

    Raises:
        AssertionError: if H/W are not multiples of ph/pw, or (for 3-d
            patches on rank-5 input) T is not a multiple of pt.
    """
    self.shape_inp = x.shape
    self.rank_inp = len(self.shape_inp)
    self.H, self.W = self.shape_inp[-2:]
    assert (self.H % self.ph) == 0 and (self.W % self.pw) == 0, (self.shape_inp, self.patch_size)
    if self.rank_inp == 5:
        self.T = self.shape_inp[self.temporal_dim]
        if self._patches_are_3d:
            assert (self.T % self.pt) == 0, (self.T, self.pt)
    else:
        # Only inputs without a time axis count as a single frame.
        self.T = 1
def split_by_time(self, x):
    """Unfold axis 1 of ``x`` into ``(self.T, size // self.T)``.

    Raises:
        AssertionError: if axis 1 is not a multiple of ``self.T``.
    """
    dims = tuple(x.shape)
    assert dims[1] % self.T == 0, (x.shape, self.T)
    per_frame = dims[1] // self.T
    return x.view(dims[0], self.T, per_frame, *dims[2:])
def merge_by_time(self, x):
    """Fold axes 1 and 2 of ``x`` into a single axis (inverse of split_by_time)."""
    dims = tuple(x.shape)
    merged = dims[1] * dims[2]
    return x.view(dims[0], merged, *dims[3:])
def video_to_patches(self, x):
    """Rearrange an image batch or movie into patches.

    Relies on ``self.rank_inp`` having been set by ``_check_shape``.

    Args:
        x: rank-4 ``[b, c, H, W]`` input (requires ``pt == 1``) or rank-5
            input laid out according to ``self.temporal_dim``.

    Returns:
        Tensor of shape ``[b, num_patches, patch_elements, c]``. Also
        stores ``N``, ``D``, ``C`` (the last three sizes) and
        ``E = embed_dim = D * C`` on ``self``.
    """
    if self.rank_inp == 4:
        assert self.pt == 1, (self.pt, x.shape)
        x = rearrange(x, 'b c (h ph) (w pw) -> b (h w) (ph pw) c', ph=self.ph, pw=self.pw)
    else:
        # The rank-5 check applies only when the input is not rank 4.
        assert self.rank_inp == 5, (x.shape, self.rank_inp, self.shape_inp)
        dim_order = 'b (t pt) c (h ph) (w pw)' if self.temporal_dim == 1 else 'b c (t pt) (h ph) (w pw)'
        x = rearrange(x, dim_order + ' -> b (t h w) (pt ph pw) c', pt=self.pt, ph=self.ph, pw=self.pw)

    self.N, self.D, self.C = x.shape[-3:]
    self.embed_dim = self.E = self.D * self.C
    return x
def patches_to_video(self, x):
    """Rearrange patches back into an image batch or movie.

    Args:
        x: rank-4 ``[B, N, D, C]`` patches, or rank-3 ``[B, N, E]`` patches
            whose last axis is unflattened to ``[self.D, E // self.D]``.
            If ``N`` is smaller than ``self.num_patches`` the missing
            patches are appended from ``get_masked_patches``.

    Returns:
        Tensor in the layout of the input last seen by ``_check_shape``:
        rank 5 with time on ``self.temporal_dim``, or rank 4 if that input
        was rank 4.
    """
    shape = x.shape
    rank = len(shape)
    if rank == 4:
        B, _N, _D, _C = shape
    else:
        assert rank == 3, rank
        B, _N, _E = shape
        assert (_E % self.D == 0), (_E, self.D)
        x = x.view(B, _N, self.D, -1)

    if _N < self.num_patches:
        # forward() stores mask_mode before calling this method; fall back
        # to the helper's default when called directly.
        masked_patches = self.get_masked_patches(
            x,
            num_patches=(self.num_patches - _N),
            mask_mode=getattr(self, 'mask_mode', 'zeros'))
        x = torch.cat([x, masked_patches], 1)

    x = rearrange(
        x,
        'b (t h w) (pt ph pw) c -> b c (t pt) (h ph) (w pw)',
        pt=self.pt, ph=self.ph, pw=self.pw,
        t=(self.T // self.pt), h=(self.H // self.ph), w=(self.W // self.pw))

    if self.rank_inp == 5 and (self.temporal_dim == 1):
        # rearrange yields channels-first; move time back to axis 1
        x = x.transpose(1, 2)
    elif self.rank_inp == 4:
        assert x.shape[2] == 1, x.shape
        x = x[:, :, 0]
    return x
@staticmethod
def get_masked_patches(x, num_patches, mask_mode='zeros'):
    """Create filler patches matching ``x`` in device, dtype and trailing dims.

    Declared static because it takes no ``self`` yet is invoked through
    the instance.

    Args:
        x: reference tensor of shape ``[B, N, ...]``.
        num_patches: size of axis 1 in the result.
        mask_mode: ``'zeros'`` or ``'gray'`` (filled with 0.5).

    Returns:
        Tensor of shape ``[B, num_patches, ...]``.

    Raises:
        NotImplementedError: for any other ``mask_mode``.
    """
    shape = x.shape
    patches_shape = (shape[0], num_patches, *shape[2:])
    if mask_mode == 'zeros':
        return torch.zeros(patches_shape, dtype=x.dtype, device=x.device)
    if mask_mode == 'gray':
        return (0.5 * torch.ones(patches_shape, device=x.device)).to(x.dtype)
    raise NotImplementedError("Haven't implemented mask_mode == %s" % mask_mode)
def average_within_patches(self, z):
    """Replace every element of a patch by the patch mean, per channel.

    Args:
        z: ``[b, n, d, c]`` patches, or ``[b, n, d*c]`` which is first
            unflattened using ``self.C`` channels.

    Returns:
        ``[b, n, d, c]`` tensor whose ``d`` axis is constant.
    """
    if z.dim() == 3:
        z = rearrange(z, 'b n (d c) -> b n d c', c=self.C)
    patch_mean = z.mean(-2, True)
    return patch_mean.repeat(1, 1, z.shape[-2], 1)
def forward(self, x, to_video=False, mask_mode='zeros'):
    """Patchify ``x`` or, with ``to_video=True``, reassemble patches.

    Args:
        x: images/movie (``to_video=False``) or patches (``to_video=True``).
        to_video: direction of the conversion.
        mask_mode: fill used for missing patches when reassembling.

    Returns:
        Patches (flattened to ``[B, N, -1]`` when ``_squeeze_channel_dim``
        is set) or the reassembled images/movie.
    """
    if to_video:  # x are patches
        assert (self.shape_inp is not None) and (self.num_patches is not None)
        self.mask_mode = mask_mode
        return self.patches_to_video(x)

    # video_to_patches() depends on rank_inp/H/W/T; record them from the
    # actual input first (repeating the check is harmless).
    self._check_shape(x)
    x = self.video_to_patches(x)
    if self._squeeze_channel_dim:
        return x.view(x.size(0), self.N, -1)
    return x
class DerivativeFlowGenerator(nn.Module):
"""Estimate flow of a two-frame predictor using torch autograd"""
# NOTE(review): the constructor's parameter list is missing from this view —
# everything between `self,` and the closing `):` must be restored from the
# original file. The bare names read below (predictor,
# imagenet_normalize_inputs, perturbation_patch_size, aggregation_patch_size,
# agg_channel_func, average_jacobian, confidence_thresh, num_samples,
# leave_one_out_sampling, agg_power, temporal_dim) are presumably those
# parameters; their defaults cannot be determined from here.
def __init__(self,
super(DerivativeFlowGenerator, self).__init__()
# Wrapped model: called in predict() and queried for patch size / dims.
self.predictor = predictor
# Patchifier built from self.patch_size, time on axis 1, flattened patches.
self.patchify = Patchify(self.patch_size, temporal_dim=1, squeeze_channel_dim=True)
# When truthy, _preprocess() applies imagenet_normalize to the input.
self.imagenet_normalize_inputs = imagenet_normalize_inputs
# Both sizes are normalised to 3-tuples and fall back to self.patch_size
# when the argument is None.
self.perturbation_patch_size = self._get_patch_size(perturbation_patch_size) or self.patch_size
self.aggregation_patch_size = self._get_patch_size(aggregation_patch_size) or self.patch_size
# NOTE(review): this call is unterminated in this view; its remaining
# arguments and closing parenthesis are missing.
self.agg_patchify = Patchify(self.aggregation_patch_size,
# Default channel aggregation: ReLU, then sum over dim -3 (kept as size 1).
self.agg_channel_func = agg_channel_func or (lambda x: F.relu(x).sum(-3, True))
# Selects between per-sample flows and flows from an averaged Jacobian in forward().
self.average_jacobian = average_jacobian
# Threshold on the Jacobian peak used by _confident_jacobian(); None disables it.
self.confidence_thresh = confidence_thresh
# Passed to masking.partition_masks as num_samples / leave_one_out.
self.num_samples = num_samples
self.leave_one_out_sampling = leave_one_out_sampling
# Exponent applied to the Jacobian in _jacobian_to_flows().
self.agg_power = agg_power
# Temporal axis of incoming movies; _preprocess() transposes when it is not 1.
self.t_dim = temporal_dim
def _get_patch_size(self, p):
    """Normalise a patch-size spec to a ``(pt, ph, pw)`` tuple.

    Args:
        p: None, an int (square spatial patch), a 2-sequence ``(ph, pw)``
            or a 3-sequence ``(pt, ph, pw)``.

    Returns:
        None for None, otherwise a 3-tuple; ``pt`` is 1 for int and
        2-sequence input.
    """
    if p is None:
        return None
    if isinstance(p, int):
        return (1, p, p)
    if len(p) == 2:
        return (1, p[0], p[1])
    assert len(p) == 3, p
    return (p[0], p[1], p[2])
def set_temporal_dim(self, t_dim):
    """Tell the predictor which of axes 1/2 is time; the other is channels.

    Args:
        t_dim: 1 or 2.

    Raises:
        ValueError: for any other value.
    """
    if t_dim == 1:
        self.predictor.t_dim = 1
        self.predictor.c_dim = 2
    elif t_dim == 2:
        self.predictor.c_dim = 1
        self.predictor.t_dim = 2
    else:
        # Only an unsupported axis is an error.
        raise ValueError("temporal_dim must be 1 or 2")
@property
def c_dim(self):
    """Channel axis reported by the predictor, or None without a predictor.

    A property because ``_preprocess`` reads it as ``self.c_dim``.
    """
    if self.predictor is None:
        return None
    return self.predictor.c_dim
@property
def patch_size(self):
    """Patch size of the predictor, if it can be determined.

    A property because the class reads it as ``self.patch_size``.

    Returns:
        ``predictor.patch_size`` when present, else the kernel size of
        ``predictor.encoder.patch_embed.proj`` when present, else None
        (also None without a predictor).
    """
    if self.predictor is None:
        return None
    if hasattr(self.predictor, 'patch_size'):
        return self.predictor.patch_size
    if hasattr(self.predictor.encoder.patch_embed, 'proj'):
        return self.predictor.encoder.patch_embed.proj.kernel_size
    return None
@property
def S(self):
    """Alias for ``self.num_samples`` (read as ``self.S`` when partitioning masks)."""
    return self.num_samples
@property
def sequence_length(self):
    """Number of frames the predictor works on.

    A property because ``mask_shape`` reads it as ``self.sequence_length``.

    Returns:
        ``predictor.sequence_length`` if present, else
        ``predictor.num_frames`` if present, else 2; None without a
        predictor.
    """
    if self.predictor is None:
        return None
    if hasattr(self.predictor, 'sequence_length'):
        return self.predictor.sequence_length
    if hasattr(self.predictor, 'num_frames'):
        return self.predictor.num_frames
    return 2
@property
def mask_shape(self):
    """Shape ``(frames, rows, cols)`` of the patch-level mask.

    A property because the class reads it as ``self.mask_shape``.

    Returns:
        ``predictor.mask_shape`` if present; otherwise it is derived from
        ``sequence_length``, the stored input shape and ``patch_size``.
        None without a predictor.
    """
    if self.predictor is None:
        return None
    if hasattr(self.predictor, 'mask_shape'):
        return self.predictor.mask_shape

    assert self.patch_size is not None
    pt, ph, pw = self.patch_size
    return (self.sequence_length // pt,
            self.inp_shape[-2] // ph,
            self.inp_shape[-1] // pw)
# Grid shape of the perturbation patches: stored input size divided by
# perturbation_patch_size along the last two axes.
# NOTE(review): the returned tuple is unterminated in this view (closing
# parenthesis missing). get_perturbation_input() builds a rank-4 tensor from
# (B, *p_mask_shape), which suggests a leading temporal entry is also missing
# here — confirm against the original file. The name is read as an attribute
# elsewhere, so a property decorator is presumably missing too.
def perturbation_mask_shape(self):
return (
self.inp_shape[-2] // self.perturbation_patch_size[-2],
self.inp_shape[-1] // self.perturbation_patch_size[-1]
@property
def p_mask_shape(self):
    """Short alias for ``self.perturbation_mask_shape``.

    A property because callers unpack it directly (``*self.p_mask_shape``).
    """
    return self.perturbation_mask_shape
# Grid shape of the aggregation patches: stored input size divided by
# aggregation_patch_size along the last two axes.
# NOTE(review): the returned tuple is unterminated in this view (closing
# parenthesis missing) and, by analogy with the perturbation grid, a leading
# temporal entry may be missing as well — confirm against the original file.
# The name is read as an attribute elsewhere, so a property decorator is
# presumably missing too.
def aggregation_mask_shape(self):
return (
self.inp_shape[-2] // self.aggregation_patch_size[-2],
self.inp_shape[-1] // self.aggregation_patch_size[-1]
@property
def a_mask_shape(self):
    """Short alias for ``self.aggregation_mask_shape``.

    A property because callers index and unpack it directly
    (``self.a_mask_shape[-2:]``).
    """
    return self.aggregation_mask_shape
def get_perturbation_input(self, x):
    """Create the all-zero perturbation the Jacobian is taken with respect to.

    Args:
        x: reference input; supplies dtype, device and the size of axis 2.

    Returns:
        Zeros of shape ``(B, *p_mask_shape)`` with a new axis inserted at
        position 2 and repeated ``x.shape[2]`` times. The base tensor is
        created with ``requires_grad=True``.
    """
    base = torch.zeros(
        (self.B, *self.p_mask_shape),
        dtype=x.dtype, device=x.device, requires_grad=True)
    repeats = x.shape[2]
    return base.unsqueeze(2).repeat(1, 1, repeats, 1, 1)
def pred_patches_to_video(self, y, x, mask):
    """input at visible positions, preds at masked positions

    Args:
        y: predicted patches; flattened to ``[-1, C]`` with ``C = y.shape[-1]``.
        x: input movie; patchified with ``self.patchify``.
        mask: boolean patch mask, True where ``y`` should be used.

    Returns:
        The movie reassembled by ``self.patchify(..., to_video=True)``.
    """
    C = y.shape[-1]
    self.patchify.D = np.prod(self.patch_size)
    x = self.patchify(x)
    y_out = torch.zeros_like(x)
    y_out[~mask] = x[~mask].view(-1, C)
    # reshape() covers both the contiguous case and the non-contiguous case
    # on which view() raises, so one assignment suffices.
    y_out[mask] = y.reshape(-1, C)
    return self.patchify(y_out, to_video=True)
def set_image_size(self, *args, **kwargs):
    """Set the predictor's image size.

    Delegates to ``predictor.set_image_size(*args, **kwargs)`` when the
    predictor has that method; otherwise stores the first positional
    argument as ``predictor.image_size``.

    Raises:
        AssertionError: if there is no predictor.
    """
    assert self.predictor is not None, "Can't set the image size without a predictor"
    if hasattr(self.predictor, 'set_image_size'):
        self.predictor.set_image_size(*args, **kwargs)
    else:
        # Fallback only: do not clobber what the predictor's own setter did.
        self.predictor.image_size = args[0]
# Run the predictor for (x, mask) and return only the last frame of the
# reassembled prediction (kept as a length-1 slice on axis 1).
def predict(self, x=None, mask=None, forward_full=False):
# Fall back to the input stored by set_input().
if x is None:
x = self.x
# generate_mask is not defined in the visible part of this file.
if mask is None:
mask = self.generate_mask(x)
# NOTE(review): x is never passed to the predictor in the visible call, so a
# leading argument line (presumably the preprocessed input) may be missing —
# confirm against the original file. mask_rectangularizer is likewise not
# defined in this view; it is applied only when the batch size is not 1.
y = self.predictor(
mask if (x.size(0) == 1) else self.mask_rectangularizer(mask), forward_full=forward_full)
# Combine predictions (masked patches) with the input (visible patches).
y = self.pred_patches_to_video(y, x, mask=mask)
# -1 % n == n - 1, i.e. the index of the last entry along axis 1.
frame = -1 % y.size(1)
y = y[:, frame:frame + 1]
return y
def _get_perturbation_func(self, x=None, mask=None):
    """Build the function whose Jacobian w.r.t. the perturbation is taken.

    Args:
        x, mask: if ``x`` is given, both are stored via ``set_input`` first.

    Returns:
        A callable mapping a patch-resolution perturbation ``y`` to the
        aggregated prediction values at ``self.targets``.
    """
    if x is not None:
        self.set_input(x, mask)

    def forward_mini_image(y):
        # Upsample the perturbation to pixel resolution, one axis at a time.
        y = y.repeat_interleave(self.perturbation_patch_size[-2], -2)
        y = y.repeat_interleave(self.perturbation_patch_size[-1], -1)
        prediction = self.predict(self.x + y, self.mask)
        # Mean within each aggregation patch, then sum over the last axis.
        pooled = self.agg_patchify(prediction).mean(-2).sum(-1)
        pooled = pooled.view(self.B, *self.a_mask_shape)
        return pooled[self.targets]

    return forward_mini_image
def _postprocess_jacobian(self, jac):
    """Scatter the target-wise Jacobian onto the aggregation grid and reduce it.

    Steps: place ``jac`` at the ``self.targets`` positions of a zero
    tensor, aggregate channels with ``agg_channel_func`` (must leave size
    1 on dim -3), keep index 0 of the next axis, reshape to
    ``(B, a_h, a_w, B, p_h, p_w)`` and keep only entries where both batch
    indices coincide.

    Returns:
        Tensor of shape ``(B, a_h, a_w, p_h, p_w)``.
    """
    scattered = jac.new_zeros((self.B, *self.a_mask_shape, *jac.shape[1:]))
    scattered[self.targets] = jac

    reduced = self.agg_channel_func(scattered)
    assert reduced.size(-3) == 1, reduced.shape
    reduced = reduced.squeeze(-3)[..., 0, :, :]  # derivative w.r.t. first frame and agg channels

    agg_h, agg_w = self.a_mask_shape[-2], self.a_mask_shape[-1]
    pert_h, pert_w = self.p_mask_shape[-2], self.p_mask_shape[-1]
    reduced = reduced.view(self.B, agg_h, agg_w, self.B, pert_h, pert_w)

    batch_idx = torch.arange(0, self.B).long().to(reduced.device)
    return reduced[batch_idx, :, :, batch_idx, :, :]  # take diagonal
def _confident_jacobian(self, jac):
    """Flag Jacobian maps whose peak exceeds ``self.confidence_thresh``.

    Returns:
        A tensor with a new axis at position 1 and the last two axes of
        ``jac`` reduced away: all ones when the threshold is None,
        otherwise 1.0 where the max over the last two axes is above it.
    """
    threshold = self.confidence_thresh
    if threshold is None:
        return torch.ones_like(jac[:, None, ..., 0, 0])
    peak = jac.amax((-2, -1))
    return (peak > threshold).float()[:, None]
def set_input(self, x, mask=None, timestamps=None):
    """Store the input movie and derived sizes on ``self``.

    Args:
        x: movie ``[B,T,C,H,W]`` or single frame ``[B,C,H,W]`` (a length-1
            time axis is inserted at position 1).
        mask: stored as ``self.mask`` when not None.
        timestamps: stored as ``self.timestamps`` when not None.

    Raises:
        AssertionError: if ``x`` is neither rank 4 nor rank 5.
    """
    if len(x.shape) == 4:
        x = x.unsqueeze(1)  # single frame -> one-frame movie
    # Check the rank after the optional unsqueeze so frames are accepted.
    assert len(x.shape) == 5, \
        "Input must be a movie of shape [B,T,C,H,W] " + \
        "or a single frame of shape [B,C,H,W]"

    self.inp_shape = x.shape
    self.x = x
    self.B = self.inp_shape[0]
    self.T = self.inp_shape[1]
    self.C = self.inp_shape[2]

    if mask is not None:
        self.mask = mask
    if timestamps is not None:
        self.timestamps = timestamps
def _preprocess(self, x):
    """Optionally imagenet-normalize ``x`` and move time to the expected axis.

    Normalization runs when ``imagenet_normalize_inputs`` is truthy; the
    axes ``t_dim`` and ``c_dim`` are swapped whenever ``t_dim`` is not 1.
    """
    out = imagenet_normalize(x) if self.imagenet_normalize_inputs else x
    if self.t_dim == 1:
        return out
    return out.transpose(self.t_dim, self.c_dim)
def _jacobian_to_flows(self, jac):
    """Convert per-target Jacobian maps into a flow field in pixels.

    Each target's map over the perturbation grid is treated as a
    distribution; its centroid (rescaled to aggregation-grid units) minus
    the target's own grid position gives the displacement, which is then
    multiplied by the aggregation patch size.

    Args:
        jac: Jacobian viewable as ``[B * a_h * a_w, 1, 1, p_h, p_w]``.

    Returns:
        Tensor of shape ``[B, 2, a_h, a_w]``.
    """
    if self.agg_power is None:
        # No exponent: one-hot at the peak of each map.
        jac = (jac == jac.amax((-2, -1), True)).float()
    else:
        # Sharpening applies only when an exponent was actually given.
        jac = torch.pow(jac, self.agg_power)

    agg_h, agg_w = self.a_mask_shape[-2], self.a_mask_shape[-1]
    pert_h, pert_w = self.p_mask_shape[-2], self.p_mask_shape[-1]

    jac = jac.view(self.B * agg_h * agg_w, 1, 1, pert_h, pert_w)
    centroids = get_distribution_centroid(jac, normalize=False).view(
        self.B, agg_h, agg_w, 2)
    rescale = [agg_h / pert_h, agg_w / pert_w]
    centroids = centroids * torch.tensor(rescale, device=centroids.device).view(1, 1, 1, 2)

    flows = centroids - \
        coordinate_ims(1, 0, self.a_mask_shape[-2:], normalize=False).to(jac.device)
    flows = flows.permute(0, 3, 1, 2)
    px_scale = torch.tensor(self.aggregation_patch_size[-2:]).float().to(flows.device).view(1, 2, 1, 1)
    flows = flows * px_scale
    return flows
def set_targets(self, targets=None, frame=-1):
    """Select the target positions whose prediction is differentiated.

    Args:
        targets: rank-4 mask; when None the stored mask image is used.
        frame: frame index (wrapped modulo ``mask_shape[0]``) of which a
            length-1 slice along axis 1 is kept.

    Stores ``self.targets``, the selected frame resized to the aggregation
    grid via ``masking.upsample_masks`` (applied to the inverted mask and
    inverted back).
    """
    frame = frame % self.mask_shape[0]
    if targets is None:
        targets = self.get_mask_image(self.mask)[:, frame:frame + 1]
    else:
        # Slice only user-supplied targets; the default is already sliced.
        assert len(targets.shape) == 4, targets.shape
        targets = targets[:, frame:frame + 1]
    self.targets = ~masking.upsample_masks(~targets, self.a_mask_shape[-2:])
def _get_mask_partition(self, mask):
    """Split the mask of the later frames into several sample masks.

    The frames after the first are partitioned by
    ``masking.partition_masks``; every resulting sample is prefixed with
    the flattened first-frame mask.

    Returns:
        List of flat masks, one per sample.
    """
    mask_image = self.get_mask_image(mask)
    first_frame = mask_image[:, 0:1]
    samples = masking.partition_masks(
        mask_image[:, 1:], num_samples=self.S, leave_one_out=self.leave_one_out_sampling)

    partition = []
    for sample in samples:
        head = first_frame.view(sample.size(0), -1)
        partition.append(torch.cat([head, sample], -1))
    return partition
# Compute the Jacobian of the perturbation function and post-process it
# with _postprocess_jacobian().
def _compute_jacobian(self, y):
# Closure over the currently stored input/mask/targets.
perturbation_func = self._get_perturbation_func()
# NOTE(review): this call is unterminated in this view — its arguments
# (presumably perturbation_func and y) and the closing parenthesis are
# missing and must be restored from the original file.
jac = torch.autograd.functional.jacobian(
jac = self._postprocess_jacobian(jac)
return jac
def _upsample_mask(self, mask):
    """Resize a patch-level mask to the spatial size of the stored input.

    The mask is reshaped to ``[B, -1, *mask_shape[-2:]]``, cast to float
    and passed to ``masking.upsample_masks`` with ``inp_shape[-2:]``.
    """
    grid_h, grid_w = self.mask_shape[-2:]
    mask_image = mask.view(mask.size(0), -1, grid_h, grid_w).float()
    return masking.upsample_masks(mask_image, self.inp_shape[-2:])
def get_mask_image(self, mask, upsample=False, invert=False, shape=None):
    """Reshape a flat mask into an image-like mask.

    Args:
        mask: mask tensor viewable as ``[-1, *shape]``.
        upsample: also resize to input resolution via ``_upsample_mask``.
        invert: return the complement of the mask.
        shape: target shape without the batch axis; defaults to
            ``self.mask_shape``.

    Returns:
        The reshaped (and optionally upsampled / inverted) mask.
    """
    if shape is None:
        shape = self.mask_shape
    mask = mask.view(-1, *shape)
    if upsample:
        mask = self._upsample_mask(mask)
    if invert:
        # `1 - mask` is rejected by torch for bool tensors; use logical not.
        mask = ~mask if mask.dtype == torch.bool else 1 - mask
    return mask
# Estimate flow for input x under several mask samples and return
# (jacobian, flows, valid).
# NOTE(review): this method is incomplete in this view. No visible line
# appends to `jacobian`, `flows` or `confident`, no visible line calls
# set_targets() although self.targets is used, and the branch structure
# around the final flow computation cannot be recovered — restore the
# missing lines from the original file before relying on it.
def forward(self, x, mask, targets=None):
# Store input and mask, then build the zero perturbation to differentiate against.
self.set_input(x, mask)
y = self.get_perturbation_input(x)
# One mask per sample; see _get_mask_partition().
mask_list = self._get_mask_partition(mask)
jacobian, flows, confident = [], [], []
for s, mask_sample in enumerate(mask_list):
self.set_input(x, mask_sample)
# Wall-clock timing of one sample; the elapsed time is printed below.
import time
t1 = time.time()
jac = self._compute_jacobian(y)
# Confidence flags resized to the aggregation grid.
conf_jac = masking.upsample_masks(self._confident_jacobian(jac), self.a_mask_shape[-2:])
# Per-sample flow, zeroed outside targets, low-confidence positions and
# positions not selected by this sample's mask for the later frames.
if not self.average_jacobian:
flow = self._jacobian_to_flows(jac) * self.targets * conf_jac * \
masking.upsample_masks(self.get_mask_image(self.mask)[:, 1:], self.a_mask_shape[-2:])
t2 = time.time()
print(t2 - t1)
# Stack per-sample results along a new trailing axis.
jacobian = torch.stack(jacobian, -1)
confident = torch.stack(confident, -1)
# valid: positions covered by each sample's mask (later frames) and confident.
valid = torch.stack([masking.upsample_masks(
self.get_mask_image(m)[:, 1:], self.a_mask_shape[-2:]) for m in mask_list], -1)
valid = valid * confident
# Averaged-Jacobian path: mean of the Jacobian over valid samples, then a
# single flow estimate restricted to positions valid in at least one sample.
if self.average_jacobian:
_valid = valid[:, 0].unsqueeze(-2).unsqueeze(-2)
jac = (jacobian * _valid.float()).sum(-1) / _valid.float().sum(-1).clamp(min=1)
flows = self._jacobian_to_flows(jac) * \
masking.upsample_masks(_valid[:, None, ..., 0, 0, :].amax(-1).bool(), self.a_mask_shape[-2:])
if targets is not None:
flows *= self.targets
# NOTE(review): the two lines below average per-sample flows over valid
# samples and presumably belong to the non-averaged branch — confirm.
flows = torch.stack(flows, -1)
flows = flows.sum(-1) / valid.float().sum(-1).clamp(min=1)
# Restrict validity to the last entry of the caller-supplied targets, if any.
valid = valid * (targets[:, -1:].unsqueeze(-1) if targets is not None else 1)
return (jacobian, flows, valid)