Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import kornia | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from einops import rearrange | |
from torch import nn | |
import cwm.eval.Flow.masking_flow as masking | |
def boltzmann(x, beta=1, eps=1e-9):
    """Exponentially sharpen ``x`` over its last two axes.

    Each slice over the last two dimensions is mapped to
    ``exp(beta * x) / max(max(exp(beta * x)), eps)``, so the largest entry of
    every slice becomes 1 (unless every entry is below ``eps``).

    The computation is carried out in log space: dividing by
    ``max(exp(z_max), eps)`` equals subtracting ``max(z_max, log(eps))`` in
    the exponent. This gives the same result as the naive formula but never
    overflows to ``inf``. The naive form produced ``inf / inf = nan`` once
    ``beta * x`` exceeded ~88 in float32.

    Args:
        x: tensor with at least two dimensions.
        beta: inverse temperature; ``None`` disables the transform.
        eps: lower bound on the normalizer (as in the naive formula).

    Returns:
        Tensor of the same shape as ``x``, or ``x`` itself when ``beta`` is None.
    """
    if beta is None:
        return x
    z = x * beta
    # log(max(exp(z_max), eps)) == max(z_max, log(eps)); subtracting it keeps
    # every exponent <= 0, so exp() cannot overflow.
    shift = z.amax((-1, -2), keepdim=True).clamp(min=float(np.log(eps)))
    return torch.exp(z - shift)
# Per-channel normalization statistics consumed by imagenet_normalize() and
# imagenet_unnormalize(): one mean and one standard deviation per image
# channel. These are the widely used ImageNet values; the channel order is
# presumably RGB — confirm against the data pipeline.
IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD = (
    (0.485, 0.456, 0.406),
    (0.229, 0.224, 0.225),
)
def imagenet_normalize(x, temporal_dim=1):
    """Standardize a 5-D video tensor with the ImageNet channel statistics.

    Args:
        x: 5-D tensor. With ``temporal_dim == 2`` the channel axis is expected
            at index 1 ([B, C, T, H, W]); for any other value it is expected
            at index 2 ([B, T, C, H, W]).
        temporal_dim: position of the time axis (1 or 2).

    Returns:
        ``(x - mean) / std`` with the statistics cast to x's device and dtype.
    """
    def _channel_stats(values):
        # Shape (1, 1, C, 1, 1), moved to (1, C, 1, 1, 1) for channel-first video.
        stats = torch.as_tensor(values).to(x.device)[None, None, :, None, None].to(x)
        return stats.transpose(1, 2) if temporal_dim == 2 else stats

    mean = _channel_stats(IMAGENET_DEFAULT_MEAN)
    std = _channel_stats(IMAGENET_DEFAULT_STD)
    return (x - mean) / std
def imagenet_unnormalize(x, temporal_dim=2):
    """Undo ImageNet standardization: return ``x * std + mean``.

    Args:
        x: 5-D tensor. With the default ``temporal_dim == 2`` the channel axis
            is expected at index 1 ([B, C, T, H, W]); otherwise at index 2.
        temporal_dim: position of the time axis. Note the default (2) differs
            from imagenet_normalize's default (1).

    Returns:
        Tensor with the ImageNet statistics re-applied, on x's device/dtype.
    """
    device = x.device
    broadcastable = []
    for values in (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD):
        stat = torch.as_tensor(values).to(device)[None, None, :, None, None].to(x)
        if temporal_dim == 2:
            stat = stat.transpose(1, 2)
        broadcastable.append(stat)
    mean, std = broadcastable
    return x * std + mean
def coordinate_ims(batch_size, seq_length, imsize, normalize=True, dtype_out=torch.float32):
    """Build a grid of (row, column) pixel coordinates.

    Args:
        batch_size: B.
        seq_length: T; ``0`` requests a "static" grid with no time axis.
        imsize: (H, W).
        normalize: if True, coordinates are mapped linearly so that the first
            row/column is -1 and the last is 1; otherwise raw indices (as
            floats) are returned. A size-1 axis maps to -1 when normalizing
            (it used to produce 0/0 = NaN).
        dtype_out: dtype of the returned tensor.

    Returns:
        Tensor of shape [B, T, H, W, 2] ([B, H, W, 2] when seq_length == 0);
        channel 0 holds the row (h) coordinate, channel 1 the column (w).
    """
    static = seq_length == 0
    T = 1 if static else seq_length
    B = batch_size
    H, W = imsize
    ones = torch.ones([B, H, W, 1], dtype=dtype_out)
    if normalize:
        # max(n - 1, 1) prevents a 0/0 division for an axis of length 1.
        h = torch.divide(torch.arange(H).to(ones), torch.tensor(max(H - 1, 1), dtype=dtype_out))
        h = 2.0 * ((h.view(1, H, 1, 1) * ones) - 0.5)
        w = torch.divide(torch.arange(W).to(ones), torch.tensor(max(W - 1, 1), dtype=dtype_out))
        w = 2.0 * ((w.view(1, 1, W, 1) * ones) - 0.5)
    else:
        h = torch.arange(H).to(ones).view(1, H, 1, 1) * ones
        w = torch.arange(W).to(ones).view(1, 1, W, 1) * ones
    h = torch.stack([h] * T, 1)
    w = torch.stack([w] * T, 1)
    hw_ims = torch.cat([h, w], -1)
    if static:
        hw_ims = hw_ims[:, 0]
    return hw_ims
def get_distribution_centroid(dist, eps=1e-9, normalize=False):
    """Return the weighted mean (h, w) coordinate of each spatial map.

    Args:
        dist: tensor of shape [B, T, 1, H, W]; each [H, W] map is used as
            unnormalized weights.
        eps: floor for each map's total weight, guarding all-zero maps.
        normalize: forwarded to coordinate_ims — True gives coordinates in
            [-1, 1], False gives pixel-index units.

    Returns:
        Tensor of shape [B, T, 2] holding the weighted mean row (h) and
        column (w) coordinate.
    """
    B, T, C, H, W = dist.shape
    assert C == 1
    weights = dist / dist.sum((-2, -1), keepdim=True).clamp(min=eps)
    coords = coordinate_ims(B, T, [H, W], normalize=normalize).to(dist.device)
    coords = coords.permute(0, 1, 4, 2, 3)  # [B, T, 2, H, W]
    return (coords * weights).sum((-2, -1))
class FlowToRgb(object):
    """Render a 2-channel flow field as an RGB image through an HSV color wheel.

    Hue encodes the direction of each flow vector, value encodes its speed
    divided by ``max_speed`` (not clamped, so speeds above ``max_speed`` give
    value > 1), and saturation is fixed to 1.

    Args:
        max_speed: speed that maps to HSV value 1.
        from_image_coordinates: when True (and from_sampling_grid is False),
            channel 0 is the row (h) displacement and channel 1 the column (w)
            displacement; they are converted to x = w, y = -h (presumably
            because image rows grow downward).
        from_sampling_grid: when True, channels are (x, y) and y is negated.
            Takes precedence over ``from_image_coordinates``.
    """

    def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
        self.max_speed = max_speed
        self.from_image_coordinates = from_image_coordinates
        self.from_sampling_grid = from_sampling_grid

    def _flow_to_xy(self, flow):
        """Split ``flow`` along axis -3 into (x, y) per the configured convention."""
        if self.from_sampling_grid:
            flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
            return flow_x, -flow_y
        if not self.from_image_coordinates:
            flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
            return flow_x, flow_y
        flow_h, flow_w = torch.split(flow, [1, 1], dim=-3)
        return flow_w, -flow_h

    @staticmethod
    def _angle_to_hue(angle):
        """Wrap an angle in radians into [0, 2*pi), the hue range kornia expects.

        atan2 returns values in [-pi, pi], and torch.fmod keeps the sign of its
        input, so ``fmod(angle, 2*pi)`` left negative hues untouched.
        torch.remainder takes the sign of the divisor and wraps them into range.
        """
        return torch.remainder(angle, 2 * np.pi)

    def __call__(self, flow):
        """Convert ``flow`` (2 channels on axis -3) to RGB (3 channels on axis -3)."""
        assert flow.size(-3) == 2, flow.shape
        flow_x, flow_y = self._flow_to_xy(flow)
        angle = torch.atan2(flow_y, flow_x)  # radians in [-pi, pi]
        speed = torch.sqrt(flow_x ** 2 + flow_y ** 2) / self.max_speed
        hue = self._angle_to_hue(angle)
        sat = torch.ones_like(hue)
        hsv = torch.cat([hue, sat, speed], -3)
        return kornia.color.hsv_to_rgb(hsv)
class Patchify(nn.Module):
    """Convert a set of images or a movie into patch vectors (and back).

    Forward mode (``to_video=False``) accepts a 4-D batch [B, C, H, W] or a
    5-D movie ([B, T, C, H, W] when temporal_dim == 1, [B, C, T, H, W] when
    temporal_dim == 2) and returns patches of shape [B, N, D, C], or
    [B, N, D * C] when ``squeeze_channel_dim`` is True, with D = pt * ph * pw.
    The input shape is cached so that ``to_video=True`` can invert the
    operation.
    """

    def __init__(self,
                 patch_size=(16, 16),
                 temporal_dim=1,
                 squeeze_channel_dim=True
                 ):
        super().__init__()
        self.set_patch_size(patch_size)
        self.temporal_dim = temporal_dim
        assert self.temporal_dim in [1, 2], self.temporal_dim
        self._squeeze_channel_dim = squeeze_channel_dim

    @property
    def num_patches(self):
        """Number of patches for the last input seen, or None before any input."""
        if (self.T is None) or (self.H is None) or (self.W is None):
            return None
        return (self.T // self.pt) * (self.H // self.ph) * (self.W // self.pw)

    def set_patch_size(self, patch_size):
        """Set a (ph, pw) or (pt, ph, pw) patch size and reset cached shapes.

        Raises:
            ValueError: if ``patch_size`` does not have 2 or 3 entries.
        """
        self.patch_size = patch_size
        if len(self.patch_size) == 2:
            self.ph, self.pw = self.patch_size
            self.pt = 1
            self._patches_are_3d = False
        elif len(self.patch_size) == 3:
            self.pt, self.ph, self.pw = self.patch_size
            self._patches_are_3d = True
        else:
            # An f-string is used because "%s" % tuple would unpack the tuple
            # and raise TypeError instead of this ValueError.
            raise ValueError(f"patch_size must be a 2- or 3-tuple, but is {self.patch_size}")
        self.shape_inp = self.rank_inp = self.H = self.W = self.T = None
        self.D = self.C = self.E = self.embed_dim = None

    def _check_shape(self, x):
        """Cache the input shape and check it is divisible by the patch size."""
        self.shape_inp = x.shape
        self.rank_inp = len(self.shape_inp)
        self.H, self.W = self.shape_inp[-2:]
        assert (self.H % self.ph) == 0 and (self.W % self.pw) == 0, (self.shape_inp, self.patch_size)
        if (self.rank_inp == 5) and self._patches_are_3d:
            self.T = self.shape_inp[self.temporal_dim]
            assert (self.T % self.pt) == 0, (self.T, self.pt)
        elif self.rank_inp == 5:
            self.T = self.shape_inp[self.temporal_dim]
        else:
            self.T = 1

    def split_by_time(self, x):
        """Reshape [B, T * K, ...] into [B, T, K, ...] using the cached T."""
        shape = x.shape
        assert shape[1] % self.T == 0, (shape, self.T)
        return x.view(shape[0], self.T, shape[1] // self.T, *shape[2:])

    def merge_by_time(self, x):
        """Inverse of split_by_time: merge axes 1 and 2."""
        shape = x.shape
        return x.view(shape[0], shape[1] * shape[2], *shape[3:])

    def video_to_patches(self, x):
        """Rearrange images/movies into patches of shape [B, N, D, C]."""
        if self.rank_inp == 4:
            assert self.pt == 1, (self.pt, x.shape)
            x = rearrange(x, 'b c (h ph) (w pw) -> b (h w) (ph pw) c', ph=self.ph, pw=self.pw)
        else:
            assert self.rank_inp == 5, (x.shape, self.rank_inp, self.shape_inp)
            dim_order = 'b (t pt) c (h ph) (w pw)' if self.temporal_dim == 1 else 'b c (t pt) (h ph) (w pw)'
            x = rearrange(x, dim_order + ' -> b (t h w) (pt ph pw) c', pt=self.pt, ph=self.ph, pw=self.pw)
        self.N, self.D, self.C = x.shape[-3:]
        self.embed_dim = self.E = self.D * self.C
        return x

    def patches_to_video(self, x):
        """Invert video_to_patches using the cached input shape.

        Accepts [B, N, D, C] or [B, N, E] patches. If fewer than
        ``num_patches`` patches are given, the remainder is filled with
        ``get_masked_patches`` using ``self.mask_mode``.
        """
        shape = x.shape
        rank = len(shape)
        if rank == 4:
            B, _N = shape[0], shape[1]
        else:
            assert rank == 3, rank
            B, _N, _E = shape
            assert (_E % self.D == 0), (_E, self.D)
            x = x.view(B, _N, self.D, -1)
        if _N < self.num_patches:
            masked_patches = self.get_masked_patches(
                x,
                num_patches=(self.num_patches - _N),
                mask_mode=self.mask_mode)
            x = torch.cat([x, masked_patches], 1)
        x = rearrange(
            x,
            'b (t h w) (pt ph pw) c -> b c (t pt) (h ph) (w pw)',
            pt=self.pt, ph=self.ph, pw=self.pw,
            t=(self.T // self.pt), h=(self.H // self.ph), w=(self.W // self.pw))
        if self.rank_inp == 5 and (self.temporal_dim == 1):
            x = x.transpose(1, 2)
        elif self.rank_inp == 4:
            assert x.shape[2] == 1, x.shape
            x = x[:, :, 0]
        return x

    @staticmethod
    def get_masked_patches(x, num_patches, mask_mode='zeros'):
        """Return ``num_patches`` constant filler patches shaped like ``x``.

        Args:
            x: patches [B, N, ...]; only its batch size, trailing dims, dtype
                and device are used.
            num_patches: number of filler patches to create.
            mask_mode: 'zeros' (fill 0) or 'gray' (fill 0.5).

        Raises:
            NotImplementedError: for any other ``mask_mode``.
        """
        patches_shape = (x.shape[0], num_patches, *x.shape[2:])
        if mask_mode == 'zeros':
            return x.new_zeros(patches_shape)
        if mask_mode == 'gray':
            return x.new_full(patches_shape, 0.5)
        raise NotImplementedError("Haven't implemented mask_mode == %s" % (mask_mode,))

    def average_within_patches(self, z):
        """Replace every element of each patch with the patch mean (over D)."""
        if len(z.shape) == 3:
            z = rearrange(z, 'b n (d c) -> b n d c', c=self.C)
        return z.mean(-2, True).repeat(1, 1, z.shape[-2], 1)

    def forward(self, x, to_video=False, mask_mode='zeros'):
        """Patchify ``x`` or, with ``to_video=True``, turn patches back into video."""
        if not to_video:
            self._check_shape(x)
            x = self.video_to_patches(x)
            # reshape (rather than view) also copes with non-contiguous layouts.
            return x if not self._squeeze_channel_dim else x.reshape(x.size(0), self.N, -1)
        # x are patches
        assert (self.shape_inp is not None) and (self.num_patches is not None)
        self.mask_mode = mask_mode
        return self.patches_to_video(x)
class DerivativeFlowGenerator(nn.Module):
    """Estimate flow of a two-frame predictor using torch autograd.

    An additive perturbation is defined on a coarse grid of
    ``perturbation_patch_size`` cells, upsampled to pixel resolution, and
    added to the input movie. The predictor's reconstruction of the last frame
    is averaged within ``aggregation_patch_size`` patches (and summed over
    channels). The Jacobian of those aggregated values with respect to the
    perturbation is then computed with ``torch.autograd.functional.jacobian``.
    For each aggregation patch, the Jacobian map over the first frame's
    perturbation grid is reduced to a location: its centroid after raising to
    ``agg_power``, or its arg-max when ``agg_power`` is None. The displacement
    from the patch's own position is reported as flow.

    The visible-patch mask is split into ``num_samples`` partitions with
    ``masking.partition_masks``. Per-partition Jacobians are either averaged
    before conversion (``average_jacobian=True``) or converted to flows
    individually and then averaged.

    NOTE(review): ``predict`` relies on ``self.generate_mask`` and
    ``self.mask_rectangularizer``, which are not defined in this class —
    presumably supplied by a subclass or mixin; confirm.
    """

    def __init__(self,
                 predictor,
                 perturbation_patch_size=None,
                 aggregation_patch_size=None,
                 agg_power=None,
                 agg_channel_func=None,
                 num_samples=1,
                 leave_one_out_sampling=False,
                 average_jacobian=True,
                 confidence_thresh=None,
                 temporal_dim=2,
                 imagenet_normalize_inputs=True):
        """
        Args:
            predictor: masked video predictor, called as
                ``predictor(x, mask, forward_full=...)``. Its ``t_dim`` and
                ``c_dim`` attributes are overwritten by ``set_temporal_dim``.
            perturbation_patch_size: int, (h, w) or (t, h, w) perturbation
                cell size; defaults to the predictor's patch size.
            aggregation_patch_size: same format; size of the patches over
                which predictions are averaged; defaults to the predictor's
                patch size.
            agg_power: exponent applied to the Jacobian before its centroid is
                taken; None uses the arg-max location(s) instead.
            agg_channel_func: reduces the Jacobian over axis -3; defaults to
                summing the ReLU of that axis.
            num_samples: number of mask partitions.
            leave_one_out_sampling: forwarded to ``masking.partition_masks``.
            average_jacobian: average Jacobians across partitions before the
                flow conversion (True) or average per-partition flows (False).
            confidence_thresh: if set, Jacobian maps whose maximum does not
                exceed it are marked invalid.
            temporal_dim: 1 or 2; axis layout the predictor expects.
            imagenet_normalize_inputs: apply ``imagenet_normalize`` before
                calling the predictor.
        """
        super(DerivativeFlowGenerator, self).__init__()
        self.predictor = predictor
        self.patchify = Patchify(self.patch_size, temporal_dim=1, squeeze_channel_dim=True)
        self.set_temporal_dim(temporal_dim)  # also records self.t_dim
        self.imagenet_normalize_inputs = imagenet_normalize_inputs
        self.perturbation_patch_size = self._get_patch_size(perturbation_patch_size) or self.patch_size
        self.aggregation_patch_size = self._get_patch_size(aggregation_patch_size) or self.patch_size
        self.agg_patchify = Patchify(self.aggregation_patch_size,
                                     temporal_dim=1,
                                     squeeze_channel_dim=False)
        self.agg_channel_func = agg_channel_func or (lambda x: F.relu(x).sum(-3, True))
        self.average_jacobian = average_jacobian
        self.confidence_thresh = confidence_thresh
        self.num_samples = num_samples
        self.leave_one_out_sampling = leave_one_out_sampling
        self.agg_power = agg_power

    def _get_patch_size(self, p):
        """Normalize an int / 2-tuple / 3-tuple patch size to (pt, ph, pw), or None."""
        if p is None:
            return None
        elif isinstance(p, int):
            return (1, p, p)
        elif len(p) == 2:
            return (1, p[0], p[1])
        else:
            assert len(p) == 3, p
            return (p[0], p[1], p[2])

    def set_temporal_dim(self, t_dim):
        """Configure the predictor's time/channel axes and record ``self.t_dim``.

        Raises:
            ValueError: if ``t_dim`` is not 1 or 2.
        """
        if t_dim == 1:
            self.predictor.t_dim = 1
            self.predictor.c_dim = 2
        elif t_dim == 2:
            self.predictor.c_dim = 1
            self.predictor.t_dim = 2
        else:
            raise ValueError("temporal_dim must be 1 or 2")
        # Kept here so later calls cannot leave _preprocess with a stale axis.
        self.t_dim = t_dim

    @property
    def c_dim(self):
        """Channel axis the predictor expects (None without a predictor)."""
        if self.predictor is None:
            return None
        return self.predictor.c_dim

    @property
    def patch_size(self):
        """Predictor patch size, read from ``patch_size`` or its patch-embed kernel."""
        if self.predictor is None:
            return None
        if hasattr(self.predictor, 'patch_size'):
            return self.predictor.patch_size
        patch_embed = getattr(getattr(self.predictor, 'encoder', None), 'patch_embed', None)
        if hasattr(patch_embed, 'proj'):
            return patch_embed.proj.kernel_size
        return None

    @property
    def S(self):
        """Number of mask partitions (alias of ``num_samples``)."""
        return self.num_samples

    @property
    def sequence_length(self):
        """Number of frames the predictor consumes (2 if it does not say)."""
        if self.predictor is None:
            return None
        elif hasattr(self.predictor, 'sequence_length'):
            return self.predictor.sequence_length
        elif hasattr(self.predictor, 'num_frames'):
            return self.predictor.num_frames
        else:
            return 2

    @property
    def mask_shape(self):
        """(T // pt, H // ph, W // pw) patch grid of the predictor's mask."""
        if self.predictor is None:
            return None
        elif hasattr(self.predictor, 'mask_shape'):
            return self.predictor.mask_shape
        assert self.patch_size is not None
        # Accept 2-D (ph, pw) patch sizes as well as (pt, ph, pw).
        pt, ph, pw = self._get_patch_size(self.patch_size)
        return (self.sequence_length // pt,
                self.inp_shape[-2] // ph,
                self.inp_shape[-1] // pw)

    @property
    def perturbation_mask_shape(self):
        """(mask time steps, H // perturbation ph, W // perturbation pw)."""
        return (
            self.mask_shape[0],
            self.inp_shape[-2] // self.perturbation_patch_size[-2],
            self.inp_shape[-1] // self.perturbation_patch_size[-1]
        )

    @property
    def p_mask_shape(self):
        """Short alias of ``perturbation_mask_shape``."""
        return self.perturbation_mask_shape

    @property
    def aggregation_mask_shape(self):
        """(1, H // aggregation ph, W // aggregation pw)."""
        return (
            1,
            self.inp_shape[-2] // self.aggregation_patch_size[-2],
            self.inp_shape[-1] // self.aggregation_patch_size[-1]
        )

    @property
    def a_mask_shape(self):
        """Short alias of ``aggregation_mask_shape``."""
        return self.aggregation_mask_shape

    def get_perturbation_input(self, x):
        """Return a zero perturbation of shape [B, T_mask, C, Hp, Wp] (caches x)."""
        self.set_input(x)
        y = torch.zeros((self.B, *self.p_mask_shape), dtype=x.dtype, device=x.device, requires_grad=True)
        y = y.unsqueeze(2).repeat(1, 1, x.shape[2], 1, 1)
        return y

    def pred_patches_to_video(self, y, x, mask):
        """Input at visible positions, preds at masked positions (mask True = masked)."""
        C = y.shape[-1]
        self.patchify._check_shape(x)
        self.patchify.D = np.prod(self.patch_size)
        x = self.patchify(x)
        y_out = torch.zeros_like(x)
        y_out[~mask] = x[~mask].view(-1, C)
        # reshape (a view when possible, a copy otherwise) replaces the former
        # bare try/except around view().
        y_out[mask] = y.reshape(-1, C)
        return self.patchify(y_out, to_video=True)

    def set_image_size(self, *args, **kwargs):
        """Forward the image size to the predictor (or set its ``image_size``)."""
        assert self.predictor is not None, "Can't set the image size without a predictor"
        if hasattr(self.predictor, 'set_image_size'):
            self.predictor.set_image_size(*args, **kwargs)
        else:
            self.predictor.image_size = args[0]

    def predict(self, x=None, mask=None, forward_full=False):
        """Run the predictor and return the last frame, shape [B, 1, C, H, W].

        Visible patches are copied from ``x``; masked patches (mask True) hold
        predictions. ``x`` defaults to the cached input; ``mask`` defaults to
        ``self.generate_mask(x)``. For batches larger than one, the mask is
        passed through ``self.mask_rectangularizer`` before the predictor.
        """
        if x is None:
            x = self.x
        if mask is None:
            mask = self.generate_mask(x)
        self.set_image_size(x.shape[-2:])
        y = self.predictor(
            self._preprocess(x),
            mask if (x.size(0) == 1) else self.mask_rectangularizer(mask), forward_full=forward_full)
        y = self.pred_patches_to_video(y, x, mask=mask)
        frame = -1 % y.size(1)
        y = y[:, frame:frame + 1]
        return y

    def _get_perturbation_func(self, x=None, mask=None):
        """Return a closure mapping a coarse perturbation to aggregated predictions.

        The closure repeats each perturbation cell to pixel size, adds it to
        the cached input, predicts the last frame, averages it within
        aggregation patches (summing channels), and returns the entries
        selected by ``self.targets``.
        """
        if (x is not None):
            self.set_input(x, mask)

        def forward_mini_image(y):
            y = y.repeat_interleave(self.perturbation_patch_size[-2], -2)
            y = y.repeat_interleave(self.perturbation_patch_size[-1], -1)
            x_pred = self.predict(self.x + y, self.mask)
            x_pred = self.agg_patchify(x_pred).mean(-2).sum(-1).view(self.B, *self.a_mask_shape)
            return x_pred[self.targets]
        return forward_mini_image

    def _postprocess_jacobian(self, jac):
        """Scatter Jacobian rows back onto the target grid and take the batch diagonal.

        Returns:
            Tensor [B, Ha, Wa, Hp, Wp]: derivative of each aggregation patch
            with respect to the first frame's perturbation grid, after
            ``agg_channel_func``.
        """
        _jac = jac.new_zeros((self.B, *self.a_mask_shape, *jac.shape[1:]))
        _jac[self.targets] = jac
        jac = self.agg_channel_func(_jac)
        assert jac.size(-3) == 1, jac.shape
        jac = jac.squeeze(-3)[..., 0, :, :]  # derivative w.r.t. first frame and agg channels
        jac = jac.view(self.B, self.a_mask_shape[-2], self.a_mask_shape[-1],
                       self.B, self.p_mask_shape[-2], self.p_mask_shape[-1])
        bs = torch.arange(0, self.B).long().to(jac.device)
        jac = jac[bs, :, :, bs, :, :]  # take diagonal: each example w.r.t. its own input
        return jac

    def _confident_jacobian(self, jac):
        """Return 1 where a Jacobian map's max exceeds confidence_thresh (all 1 if unset)."""
        if self.confidence_thresh is None:
            return torch.ones_like(jac[:, None, ..., 0, 0])
        conf = (jac.amax((-2, -1)) > self.confidence_thresh).float()[:, None]
        return conf

    def set_input(self, x, mask=None, timestamps=None):
        """Cache ``x`` as a 5-D movie [B, T, C, H, W] plus optional mask/timestamps."""
        shape = x.shape
        if len(shape) == 4:
            x = x.unsqueeze(1)
        else:
            assert len(shape) == 5, \
                "Input must be a movie of shape [B,T,C,H,W] " + \
                "or a single frame of shape [B,C,H,W]"
        self.inp_shape = x.shape
        self.x = x
        self.B = self.inp_shape[0]
        self.T = self.inp_shape[1]
        self.C = self.inp_shape[2]
        if mask is not None:
            self.mask = mask
        if timestamps is not None:
            self.timestamps = timestamps

    def _preprocess(self, x):
        """Optionally ImageNet-normalize, then move axes to the predictor's layout."""
        if self.imagenet_normalize_inputs:
            x = imagenet_normalize(x)
        if self.t_dim != 1:
            x = x.transpose(self.t_dim, self.c_dim)
        return x

    def _jacobian_to_flows(self, jac):
        """Convert Jacobian maps [B, Ha, Wa, Hp, Wp] to flows [B, 2, Ha, Wa] in pixels.

        Channel 0 is the row (h) displacement, channel 1 the column (w).
        """
        if self.agg_power is None:
            jac = (jac == jac.amax((-2, -1), True)).float()
        else:
            jac = torch.pow(jac, self.agg_power)
        jac = jac.view(self.B * np.prod(self.a_mask_shape[-2:]), 1, 1, *self.p_mask_shape[-2:])
        centroids = get_distribution_centroid(jac, normalize=False).view(
            self.B, self.a_mask_shape[-2], self.a_mask_shape[-1], 2)
        # Centroids are in perturbation-grid units; rescale to aggregation-grid units.
        rescale = [self.a_mask_shape[-2] / self.p_mask_shape[-2],
                   self.a_mask_shape[-1] / self.p_mask_shape[-1]]
        centroids = centroids * torch.tensor(rescale, device=centroids.device).view(1, 1, 1, 2)
        flows = centroids - \
            coordinate_ims(1, 0, self.a_mask_shape[-2:], normalize=False).to(jac.device)
        flows = flows.permute(0, 3, 1, 2)
        px_scale = torch.tensor(self.aggregation_patch_size[-2:]).float().to(flows.device).view(1, 2, 1, 1)
        flows *= px_scale
        return flows

    def set_targets(self, targets=None, frame=-1):
        """Set ``self.targets`` on the aggregation grid for the given frame.

        When ``targets`` is None the current mask image's frame is used.
        """
        frame = frame % self.mask_shape[0]
        if targets is None:
            targets = self.get_mask_image(self.mask)[:, frame:frame + 1]
        else:
            assert len(targets.shape) == 4, targets.shape
            targets = targets[:, frame:frame + 1]
        self.targets = ~masking.upsample_masks(~targets, self.a_mask_shape[-2:])

    def _get_mask_partition(self, mask):
        """Partition frames 1: of the mask into S samples, re-attaching frame 0's mask."""
        mask = self.get_mask_image(mask)
        mask_list = masking.partition_masks(
            mask[:, 1:], num_samples=self.S, leave_one_out=self.leave_one_out_sampling)
        return [torch.cat([mask[:, 0:1].view(m.size(0), -1), m], -1)
                for m in mask_list]

    def _compute_jacobian(self, y):
        """Jacobian of the perturbation function at ``y``, post-processed."""
        perturbation_func = self._get_perturbation_func()
        jac = torch.autograd.functional.jacobian(
            perturbation_func,
            y,
            vectorize=False)
        jac = self._postprocess_jacobian(jac)
        return jac

    def _upsample_mask(self, mask):
        """Upsample a patch mask to the input's spatial resolution."""
        return masking.upsample_masks(
            mask.view(mask.size(0), -1, *self.mask_shape[-2:]).float(), self.inp_shape[-2:])

    def get_mask_image(self, mask, upsample=False, invert=False, shape=None):
        """Reshape a flat mask to [-1, *shape] (default ``mask_shape``), optionally upsampled/inverted."""
        if shape is None:
            shape = self.mask_shape
        mask = mask.view(-1, *shape)
        if upsample:
            mask = self._upsample_mask(mask)
        if invert:
            mask = 1 - mask
        return mask

    def forward(self, x, mask, targets=None):
        """Estimate flow from the predictor's Jacobian.

        Args:
            x: movie [B, T, C, H, W] or a single frame [B, C, H, W].
            mask: patch mask for the predictor (True = masked).
            targets: optional 4-D boolean mask image selecting where flow is
                computed; when None the last frame of each mask sample is used.

        Returns:
            (jacobian, flows, valid): per-partition Jacobians stacked on the
            last axis; the flow on the aggregation grid in pixels (channel 0
            rows/h, channel 1 columns/w); per-partition validity masks stacked
            on the last axis.
        """
        self.set_input(x, mask)
        y = self.get_perturbation_input(x)
        mask_list = self._get_mask_partition(mask)

        jacobian, flows, confident = [], [], []
        for mask_sample in mask_list:
            self.set_input(x, mask_sample)
            self.set_targets(targets)
            jac = self._compute_jacobian(y)
            conf_jac = masking.upsample_masks(self._confident_jacobian(jac), self.a_mask_shape[-2:])
            jacobian.append(jac)
            confident.append(conf_jac)
            if not self.average_jacobian:
                flow = self._jacobian_to_flows(jac) * self.targets * conf_jac * \
                    masking.upsample_masks(self.get_mask_image(self.mask)[:, 1:], self.a_mask_shape[-2:])
                flows.append(flow)

        jacobian = torch.stack(jacobian, -1)
        confident = torch.stack(confident, -1)
        valid = torch.stack([masking.upsample_masks(
            self.get_mask_image(m)[:, 1:], self.a_mask_shape[-2:]) for m in mask_list], -1)
        valid = valid * confident

        if self.average_jacobian:
            _valid = valid[:, 0].unsqueeze(-2).unsqueeze(-2)
            jac = (jacobian * _valid.float()).sum(-1) / _valid.float().sum(-1).clamp(min=1)
            flows = self._jacobian_to_flows(jac) * \
                masking.upsample_masks(_valid[:, None, ..., 0, 0, :].amax(-1).bool(), self.a_mask_shape[-2:])
            if targets is not None:
                self.set_targets(targets)
                flows *= self.targets
        else:
            flows = torch.stack(flows, -1)
            flows = flows.sum(-1) / valid.float().sum(-1).clamp(min=1)

        valid = valid * (targets[:, -1:].unsqueeze(-1) if targets is not None else 1)
        return (jacobian, flows, valid)