import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

def upsample_masks(masks, size, thresh=0.5):
    """Resize patch-level masks to `size` (H, W), preserving the input dtype.

    Uses strided subsampling when downscaling, repetition when the target is
    an integer multiple of the source, and interpolation otherwise.
    """
    shape = masks.shape
    dtype = masks.dtype
    h, w = shape[-2:]
    H, W = size
    if (H == h) and (W == w):
        return masks
    elif (H < h) and (W < w):
        s = (h // H, w // W)
        return masks[..., ::s[0], ::s[1]]

    masks = masks.unsqueeze(-2).unsqueeze(-1)
    masks = masks.repeat(*([1] * (len(shape) - 2)), 1, H // h, 1, W // w)
    if ((H % h) == 0) and ((W % w) == 0):
        masks = masks.view(*shape[:-2], H, W)
    else:
        _H = np.prod(masks.shape[-4:-2])
        _W = np.prod(masks.shape[-2:])
        masks = transforms.Resize(size)(masks.view(-1, 1, _H, _W)) > thresh
        masks = masks.view(*shape[:-2], H, W).to(dtype)
    return masks
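
# Illustrative usage sketch (hypothetical helper, not part of the original
# file): upsample an 8x8 patch-level mask to pixel resolution. The shapes
# below are arbitrary example values.
def _demo_upsample_masks():
    patch_masks = torch.rand(2, 3, 8, 8) > 0.5      # (B, T, h, w), bool
    pixel_masks = upsample_masks(patch_masks, (224, 224))
    assert pixel_masks.shape == (2, 3, 224, 224)
    assert pixel_masks.dtype == torch.bool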

def partition_masks(masks, num_samples=2, leave_one_out=False):
    """Split the visible (False) patches of each mask into `num_samples`
    disjoint groups, returning one mask per group."""
    B = masks.shape[0]
    S = num_samples
    masks = masks.view(B, -1)
    partitioned = [torch.ones_like(masks) for _ in range(S)]
    for b in range(B):
        vis_inds = torch.where(~masks[b])[0]
        vis_inds = vis_inds[torch.randperm(vis_inds.size(0))]
        if leave_one_out:
            # Reveal all visible patches except the s-th subset, which stays masked.
            for s in range(S):
                partitioned[s][b][vis_inds] = 0
                partitioned[s][b][vis_inds[s::S]] = 1
        else:
            # Reveal only the s-th subset of the visible patches.
            for s in range(S):
                partitioned[s][b][vis_inds[s::S]] = 0
    return partitioned
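
# Illustrative usage sketch (hypothetical helper, for exposition only):
# split the visible patches of a random mask into two disjoint views.
def _demo_partition_masks():
    masks = torch.rand(4, 64) > 0.75                # True = masked
    parts = partition_masks(masks, num_samples=2)
    # No patch is visible in both partitions, and together they cover
    # exactly the originally visible patches.
    assert not (~parts[0] & ~parts[1]).any()
    assert ((~parts[0] | ~parts[1]) == ~masks).all()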

class RectangularizeMasks(nn.Module):
    """Make sure all masks in a batch have the same number of 1s and 0s."""
    def __init__(self, truncation_mode='min'):
        super().__init__()
        self._mode = truncation_mode
        assert self._mode in ['min', 'max', 'mean', 'full', 'none', None], (self._mode)

    def set_mode(self, mode):
        self._mode = mode

    def __call__(self, masks):
        if self._mode in ['none', None]:
            return masks
        assert isinstance(masks, torch.Tensor), type(masks)
        if self._mode == 'full':
            return torch.ones_like(masks)
        shape = masks.shape
        masks = masks.flatten(1)
        B, N = masks.shape
        num_masked = masks.float().sum(-1)
        # Target number of masked patches per example.
        M = {
            'min': torch.amin, 'max': torch.amax, 'mean': torch.mean
        }[self._mode](num_masked).long()

        num_changes = num_masked.long() - M
        for b in range(B):
            nc = num_changes[b]
            if nc > 0:
                # Too many masked patches: unmask `nc` of them at random.
                inds = torch.where(masks[b])[0]
                inds = inds[torch.randperm(inds.size(0))[:nc].to(inds.device)]
                masks[b, inds] = 0
            elif nc < 0:
                # Too few masked patches: mask `-nc` visible ones at random.
                inds = torch.where(~masks[b])[0]
                inds = inds[torch.randperm(inds.size(0))[:-nc].to(inds.device)]
                masks[b, inds] = 1
        if list(masks.shape) != list(shape):
            masks = masks.view(*shape)
        return masks
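
# Illustrative usage sketch (hypothetical, values arbitrary): equalize the
# number of masked patches across a batch.
def _demo_rectangularize_masks():
    masks = torch.rand(4, 64) > 0.5
    rect = RectangularizeMasks(truncation_mode='min')(masks)
    # Every example now has the same number of masked (True) patches.
    assert rect.long().sum(-1).unique().numel() == 1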

class UniformMaskingGenerator(object):
    def __init__(self, input_size, mask_ratio, seed=None, clumping_factor=1,
                 randomize_num_visible=False):
        self.frames = None
        # Check for a bare int first: calling len() on it would raise TypeError.
        if isinstance(input_size, int):
            self.height = self.width = input_size
        elif len(input_size) == 3:
            self.frames, self.height, self.width = input_size
        elif len(input_size) == 2:
            self.height, self.width = input_size
        elif len(input_size) == 1:
            self.height = self.width = input_size[0]

        self.clumping_factor = clumping_factor
        self.pad_h = self.height % self.c[0]
        self.pad_w = self.width % self.c[1]
        self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
        self.mask_ratio = mask_ratio

        self.rng = np.random.RandomState(seed=seed)
        self.randomize_num_visible = randomize_num_visible

    @property
    def num_masks_per_frame(self):
        if not hasattr(self, '_num_masks_per_frame'):
            self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
        return self._num_masks_per_frame

    @num_masks_per_frame.setter
    def num_masks_per_frame(self, val):
        self._num_masks_per_frame = val
        self._mask_ratio = (val / self.num_patches_per_frame)

    @property
    def c(self):
        """Clumping factor as an (h, w) pair."""
        if isinstance(self.clumping_factor, int):
            return (self.clumping_factor, self.clumping_factor)
        else:
            return self.clumping_factor[:2]

    @property
    def mask_ratio(self):
        return self._mask_ratio

    @mask_ratio.setter
    def mask_ratio(self, val):
        self._mask_ratio = val
        self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)

    @property
    def num_visible(self):
        return self.num_patches_per_frame - self.num_masks_per_frame

    @num_visible.setter
    def num_visible(self, val):
        self.num_masks_per_frame = self.num_patches_per_frame - val

    def __repr__(self):
        repr_str = "Mask: total patches per frame {}, mask patches per frame {}, mask ratio {}, randomize num visible? {}".format(
            self.num_patches_per_frame, self.num_masks_per_frame, self.mask_ratio, self.randomize_num_visible
        )
        return repr_str

    def sample_mask_per_frame(self):
        num_masks = self.num_masks_per_frame
        if self.randomize_num_visible:
            num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
        mask = np.hstack([
            np.zeros(self.num_patches_per_frame - num_masks),
            np.ones(num_masks)])
        self.rng.shuffle(mask)
        if max(*self.c) > 1:
            # Expand each patch decision into a (c[0], c[1]) clump of patches.
            mask = mask.reshape(self.height // self.c[0],
                                1,
                                self.width // self.c[1],
                                1)
            mask = np.tile(mask, (1, self.c[0], 1, self.c[1]))
            mask = mask.reshape((self.height - self.pad_h, self.width - self.pad_w))
            # Randomly split any remainder rows/cols between the two sides,
            # padding with masked (1) patches.
            _pad_h = self.rng.choice(range(self.pad_h + 1))
            pad_h = (self.pad_h - _pad_h, _pad_h)
            _pad_w = self.rng.choice(range(self.pad_w + 1))
            pad_w = (self.pad_w - _pad_w, _pad_w)
            mask = np.pad(mask,
                          (pad_h, pad_w),
                          constant_values=1
                          ).reshape((self.height, self.width))
        return mask

    def __call__(self, num_frames=None):
        num_frames = (num_frames or self.frames) or 1
        masks = np.stack([self.sample_mask_per_frame() for _ in range(num_frames)]).flatten()
        return masks
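
# Illustrative usage sketch (values are arbitrary): sample per-frame masks
# for a 2-frame clip on an 8x8 patch grid at 75% masking.
def _demo_uniform_masking_generator():
    gen = UniformMaskingGenerator(input_size=(2, 8, 8), mask_ratio=0.75, seed=0)
    mask = gen()                                    # flat numpy array, 1 = masked
    assert mask.shape == (2 * 8 * 8,)
    assert int(mask.sum()) == 2 * 48                # 48 masked patches per frame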

class TubeMaskingGenerator(UniformMaskingGenerator):
    """Samples one mask and repeats it across all frames ("tube" masking)."""
    def __call__(self, num_frames=None):
        num_frames = (num_frames or self.frames) or 1
        masks = np.tile(self.sample_mask_per_frame(), (num_frames, 1)).flatten()
        return masks
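
# Illustrative usage sketch: tube masks are identical across frames.
def _demo_tube_masking_generator():
    gen = TubeMaskingGenerator(input_size=(2, 8, 8), mask_ratio=0.75, seed=0)
    mask = gen().reshape(2, -1)
    assert (mask[0] == mask[1]).all()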

class RotatedTableMaskingGenerator(TubeMaskingGenerator):
    """Leaves the first frame(s) fully visible and tube-masks the rest."""
    def __init__(self, tube_length=None, *args, **kwargs):
        super(RotatedTableMaskingGenerator, self).__init__(*args, **kwargs)
        self.tube_length = tube_length

    def __call__(self, num_frames=None):
        num_frames = (num_frames or self.frames) or 2
        tube_length = self.tube_length or (num_frames - 1)
        table_thickness = num_frames - tube_length
        assert tube_length < num_frames, (tube_length, num_frames)

        tubes = super().__call__(num_frames=tube_length)
        top = np.zeros(table_thickness * self.height * self.width).astype(tubes.dtype).flatten()
        masks = np.concatenate([top, tubes], 0)
        return masks
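
# Illustrative usage sketch: a "rotated table" mask reveals the first frame
# entirely and tube-masks the remaining frames. Keyword arguments are used
# because `tube_length` is the first positional parameter.
def _demo_rotated_table_masking_generator():
    gen = RotatedTableMaskingGenerator(
        tube_length=1, input_size=(2, 8, 8), mask_ratio=0.75, seed=0)
    mask = gen(num_frames=2).reshape(2, -1)
    assert int(mask[0].sum()) == 0                  # first frame fully visible
    assert int(mask[1].sum()) == 48                 # second frame 75% masked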

class PytorchMaskGeneratorWrapper(nn.Module):
    """PyTorch wrapper for numpy masking generators."""
    def __init__(self,
                 mask_generator=TubeMaskingGenerator,
                 *args, **kwargs):
        super().__init__()
        self.mask_generator = mask_generator(*args, **kwargs)

    @property
    def mask_ratio(self):
        return self.mask_generator.mask_ratio

    @mask_ratio.setter
    def mask_ratio(self, value):
        self.mask_generator.mask_ratio = value

    def forward(self, device='cuda', dtype_out=torch.bool, **kwargs):
        masks = self.mask_generator(**kwargs)
        masks = torch.tensor(masks).to(device).to(dtype_out)
        return masks
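
# Illustrative usage sketch: wrap a numpy generator so it returns torch bool
# tensors. `device='cpu'` is used here so the sketch runs anywhere.
def _demo_pytorch_mask_generator_wrapper():
    wrapper = PytorchMaskGeneratorWrapper(
        mask_generator=TubeMaskingGenerator,
        input_size=(2, 8, 8), mask_ratio=0.75, seed=0)
    masks = wrapper(device='cpu')
    assert masks.dtype == torch.bool and masks.shape == (128,)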

class MaskingGenerator(nn.Module):
    """PyTorch base class for masking generators."""
    def __init__(self,
                 input_size,
                 mask_ratio,
                 seed=0,
                 visible_frames=0,
                 clumping_factor=1,
                 randomize_num_visible=False,
                 create_on_cpu=True,
                 always_batch=False):
        super().__init__()
        self.frames = None
        # Check for a bare int first: calling len() on it would raise TypeError.
        if isinstance(input_size, int):
            self.height = self.width = input_size
        elif len(input_size) == 3:
            self.frames, self.height, self.width = input_size
        elif len(input_size) == 2:
            self.height, self.width = input_size
        elif len(input_size) == 1:
            self.height = self.width = input_size[0]

        self.clumping_factor = clumping_factor
        self.pad_h = self.height % self.c[0]
        self.pad_w = self.width % self.c[1]
        self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
        self.mask_ratio = mask_ratio
        self.visible_frames = visible_frames
        self.always_batch = always_batch
        self.create_on_cpu = create_on_cpu

        self.rng = np.random.RandomState(seed=seed)
        self._set_torch_seed(seed)

        self.randomize_num_visible = randomize_num_visible

    @property
    def num_masks_per_frame(self):
        if not hasattr(self, '_num_masks_per_frame'):
            self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
        return self._num_masks_per_frame

    @num_masks_per_frame.setter
    def num_masks_per_frame(self, val):
        self._num_masks_per_frame = val
        self._mask_ratio = (val / self.num_patches_per_frame)

    @property
    def c(self):
        if isinstance(self.clumping_factor, int):
            return (self.clumping_factor,) * 2
        else:
            return self.clumping_factor[:2]

    @property
    def mask_ratio(self):
        return self._mask_ratio

    @mask_ratio.setter
    def mask_ratio(self, val):
        self._mask_ratio = val
        self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)

    @property
    def num_visible(self):
        return self.num_patches_per_frame - self.num_masks_per_frame

    @num_visible.setter
    def num_visible(self, val):
        self.num_masks_per_frame = self.num_patches_per_frame - val

    def _set_torch_seed(self, seed):
        self.seed = seed
        torch.manual_seed(self.seed)

    def __repr__(self):
        repr_str = ("Class: {}\nMask: total patches per mask {},\n" +
                    "mask patches per mask {}, visible patches per mask {}, mask ratio {:0.3f}\n" +
                    "randomize num visible? {}").format(
            type(self).__name__, self.num_patches_per_frame,
            self.num_masks_per_frame, self.num_visible, self.mask_ratio,
            self.randomize_num_visible
        )
        return repr_str

    def sample_mask_per_frame(self, *args, **kwargs):
        num_masks = self.num_masks_per_frame
        if self.randomize_num_visible:
            num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
        mask = torch.cat([
            torch.zeros([self.num_patches_per_frame - num_masks]),
            torch.ones([num_masks])], 0).bool()
        inds = torch.randperm(mask.size(0)).long()
        mask = mask[inds]

        if max(*self.c) > 1:
            # Expand each patch decision into a (c[0], c[1]) clump of patches.
            mask = mask.view(self.height // self.c[0],
                             1,
                             self.width // self.c[1],
                             1)
            mask = torch.tile(mask, (1, self.c[0], 1, self.c[1]))
            mask = mask.reshape(self.height - self.pad_h, self.width - self.pad_w)
            # Randomly split any remainder rows/cols between the two sides,
            # padding with masked (True) patches.
            _pad_h = self.rng.choice(range(self.pad_h + 1))
            pad_h = (self.pad_h - _pad_h, _pad_h)
            _pad_w = self.rng.choice(range(self.pad_w + 1))
            pad_w = (self.pad_w - _pad_w, _pad_w)
            mask = F.pad(mask,
                         pad_w + pad_h,
                         mode='constant',
                         value=1)
            mask = mask.reshape(self.height, self.width)
        return mask

    def forward(self, x=None, num_frames=None):
        num_frames = (num_frames or self.frames) or 1
        if isinstance(x, torch.Tensor):
            batch_size = x.size(0)
            masks = torch.stack([
                torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
                for b in range(batch_size)], 0)
            if not self.create_on_cpu:
                masks = masks.to(x.device)
            if batch_size == 1 and not self.always_batch:
                masks = masks.squeeze(0)
        else:
            batch_size = 1
            masks = torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
            if self.always_batch:
                masks = masks[None]

        if self.visible_frames > 0:
            # Prepend fully visible (all-False) frames to each flattened mask.
            vis = torch.zeros(self.height * self.width, dtype=torch.bool, device=masks.device)
            if masks.dim() > 1:
                vis = vis[None].repeat(masks.shape[0], 1)
            masks = torch.cat(([vis] * self.visible_frames) + [masks], -1)
        return masks
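
# Illustrative usage sketch: the nn.Module generator samples an independent
# mask per batch element; `always_batch=True` keeps the batch dimension.
def _demo_masking_generator():
    gen = MaskingGenerator(input_size=(2, 8, 8), mask_ratio=0.75, seed=0,
                           always_batch=True)
    x = torch.zeros(4, 3, 224, 224)                 # only x.size(0) is used
    masks = gen(x)
    assert masks.shape == (4, 2 * 8 * 8) and masks.dtype == torch.bool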