# Spaces: Running on CPU Upgrade
# Running on CPU Upgrade
# (HuggingFace Spaces page text captured alongside the source; kept as a
# comment so the file parses.)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
def upsample_masks(masks, size, thresh=0.5):
    """Resize a stack of patch masks to spatial size ``(H, W)``.

    Args:
        masks: tensor whose last two dims are spatial, e.g. (B, [T,] h, w).
        size: target ``(H, W)``.
        thresh: binarization threshold used only when the target size is not
            an integer multiple of the source size.

    Returns:
        Tensor with the same leading dims and spatial size ``(H, W)``.
    """
    shape = masks.shape
    dtype = masks.dtype
    h, w = shape[-2:]
    H, W = size
    if (H == h) and (W == w):
        # Already the requested size.
        return masks
    elif (H < h) and (W < w):
        # Downsample by strided subsampling (assumes integer ratios).
        s = (h // H, w // W)
        return masks[..., ::s[0], ::s[1]]

    # Upsample: tile each source cell into an (H//h, W//w) block.
    masks = masks.unsqueeze(-2).unsqueeze(-1)
    masks = masks.repeat(*([1] * (len(shape) - 2)), 1, H // h, 1, W // w)
    if ((H % h) == 0) and ((W % w) == 0):
        masks = masks.view(*shape[:-2], H, W)
    else:
        # Non-integer ratio: interpolate to the exact size, then re-binarize.
        _H = np.prod(masks.shape[-4:-2])
        _W = np.prod(masks.shape[-2:])
        masks = transforms.Resize(size)(masks.view(-1, 1, _H, _W)) > thresh
        # Bug fix: restore leading dims with shape[:-2] (original hard-coded
        # shape[:2], which only worked for 4-D input) and cast back to the
        # *input* dtype (original `.to(masks.dtype)` was a no-op on the bool
        # comparison result, leaving `dtype` unused and the two branches
        # returning different dtypes).
        masks = masks.view(*shape[:-2], H, W).to(dtype)
    return masks
def partition_masks(masks, num_samples=2, leave_one_out=False):
    """Split the visible positions of each mask into disjoint random groups.

    Args:
        masks: (B, ...) boolean tensor; True entries are treated as masked
            and False entries as visible (inferred from the `~masks` lookup).
        num_samples: number of partitions to produce.
        leave_one_out: if True, each partition keeps only its own share of
            visible positions (everything else masked); otherwise each
            partition masks out its own share and keeps the rest.

    Returns:
        List of `num_samples` tensors shaped (B, N), all-ones except at the
        assigned visible positions.
    """
    batch, parts = masks.shape[0], num_samples
    flat = masks.view(batch, -1)
    out = [torch.ones_like(flat) for _ in range(parts)]
    for row in range(batch):
        # Shuffle the visible indices so group assignment is random.
        visible = torch.where(~flat[row])[0]
        visible = visible[torch.randperm(visible.size(0))]
        for k in range(parts):
            if leave_one_out:
                out[k][row][visible] = 0
                out[k][row][visible[k::parts]] = 1
            else:
                out[k][row][visible[k::parts]] = 0
    return out
class RectangularizeMasks(nn.Module):
    """Make sure all masks in a batch have same number of 1s and 0s"""

    def __init__(self, truncation_mode='min'):
        super().__init__()
        self._mode = truncation_mode
        assert self._mode in ['min', 'max', 'mean', 'full', 'none', None], (self._mode)

    def set_mode(self, mode):
        # Switch the truncation mode after construction (no validation,
        # matching the original behavior).
        self._mode = mode

    def __call__(self, masks):
        # 'none'/None: pass masks through untouched.
        if self._mode in ['none', None]:
            return masks
        assert isinstance(masks, torch.Tensor), type(masks)
        # 'full': everything masked.
        if self._mode == 'full':
            return torch.ones_like(masks)

        orig_shape = masks.shape
        flat = masks.flatten(1)  # may be a view: edits below mutate the input
        num_rows = flat.shape[0]
        counts = flat.float().sum(-1)
        # Target count of 1s per row, reduced across the batch.
        reducer = {'min': torch.amin, 'max': torch.amax, 'mean': torch.mean}[self._mode]
        target = reducer(counts).long()
        deltas = counts.long() - target
        for row in range(num_rows):
            delta = deltas[row]
            if delta > 0:
                # Too many 1s: flip `delta` randomly chosen set entries to 0.
                candidates = torch.where(flat[row])[0]
                pick = candidates[torch.randperm(candidates.size(0))[:delta].to(candidates.device)]
                flat[row, pick] = 0
            elif delta < 0:
                # Too few 1s: flip `-delta` randomly chosen unset entries to 1.
                candidates = torch.where(~flat[row])[0]
                pick = candidates[torch.randperm(candidates.size(0))[:-delta].to(candidates.device)]
                flat[row, pick] = 1
        if list(flat.shape) != list(orig_shape):
            flat = flat.view(*orig_shape)
        return flat
class UniformMaskingGenerator(object):
    """Numpy-based generator of uniform random patch masks (1 = masked).

    Each call samples, per frame, a flat mask over the patch grid with
    `num_masks_per_frame` ones placed uniformly at random. With
    `clumping_factor > 1`, masking decisions are made on a coarse grid and
    tiled into c_h x c_w blocks, with random padding when the clump size does
    not evenly divide the grid.
    """

    def __init__(self, input_size, mask_ratio, seed=None, clumping_factor=1, randomize_num_visible=False):
        self.frames = None
        if len(input_size) == 3:
            self.frames, self.height, self.width = input_size
        elif len(input_size) == 2:
            self.height, self.width = input_size
        elif len(input_size) == 1 or isinstance(input_size, int):
            # NOTE(review): a bare int would hit len() first and raise
            # TypeError before reaching this branch — pass a 1-tuple instead.
            self.height = self.width = input_size

        self.clumping_factor = clumping_factor
        # Padding needed when the clump size does not evenly divide the grid.
        self.pad_h = self.height % self.c[0]
        self.pad_w = self.width % self.c[1]
        self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
        self.mask_ratio = mask_ratio  # routed through the property setter below

        self.rng = np.random.RandomState(seed=seed)
        self.randomize_num_visible = randomize_num_visible

    # Bug fix: these accessors had lost their decorators and were plain
    # methods, so e.g. `self.c[0]` in __init__ indexed a bound method and the
    # duplicate defs shadowed each other. Restored @property / setter pairs.
    @property
    def num_masks_per_frame(self):
        # Lazily derived from mask_ratio on first access.
        if not hasattr(self, '_num_masks_per_frame'):
            self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
        return self._num_masks_per_frame

    @num_masks_per_frame.setter
    def num_masks_per_frame(self, val):
        self._num_masks_per_frame = val
        self._mask_ratio = (val / self.num_patches_per_frame)

    @property
    def c(self):
        # Clumping factor normalized to an (h, w) pair.
        if isinstance(self.clumping_factor, int):
            return (self.clumping_factor, self.clumping_factor)
        else:
            return self.clumping_factor[:2]

    @property
    def mask_ratio(self):
        return self._mask_ratio

    @mask_ratio.setter
    def mask_ratio(self, val):
        self._mask_ratio = val
        self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)

    @property
    def num_visible(self):
        return self.num_patches_per_frame - self.num_masks_per_frame

    @num_visible.setter
    def num_visible(self, val):
        self.num_masks_per_frame = self.num_patches_per_frame - val

    def __repr__(self):
        repr_str = "Mask: total patches per frame {}, mask patches per frame {}, mask ratio {}, random num num visible? {}".format(
            self.num_patches_per_frame, self.num_masks_per_frame, self.mask_ratio, self.randomize_num_visible
        )
        return repr_str

    def sample_mask_per_frame(self):
        """Sample one flat (height*width,) mask; 1.0 = masked, 0.0 = visible."""
        num_masks = self.num_masks_per_frame
        if self.randomize_num_visible:
            # Randomize upward: at least the configured number of masks.
            num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
        mask = np.hstack([
            np.zeros(self.num_patches_per_frame - num_masks),
            np.ones(num_masks)])
        self.rng.shuffle(mask)
        if max(*self.c) > 1:
            # Expand each coarse decision into a c[0] x c[1] block, then pad
            # (with masked values) at a random offset to reach (height, width).
            mask = mask.reshape(self.height // self.c[0],
                                1,
                                self.width // self.c[1],
                                1)
            mask = np.tile(mask, (1, self.c[0], 1, self.c[1]))
            mask = mask.reshape((self.height - self.pad_h, self.width - self.pad_w))
            _pad_h = self.rng.choice(range(self.pad_h + 1))
            pad_h = (self.pad_h - _pad_h, _pad_h)
            _pad_w = self.rng.choice(range(self.pad_w + 1))
            pad_w = (self.pad_w - _pad_w, _pad_w)
            mask = np.pad(mask,
                          (pad_h, pad_w),
                          constant_values=1
                          ).reshape((self.height, self.width))
        return mask

    def __call__(self, num_frames=None):
        """Sample independent per-frame masks and flatten to a 1-D array."""
        num_frames = (num_frames or self.frames) or 1
        masks = np.stack([self.sample_mask_per_frame() for _ in range(num_frames)]).flatten()
        return masks
class TubeMaskingGenerator(UniformMaskingGenerator):
    """Tube masking: one spatial mask, repeated identically across frames."""

    def __call__(self, num_frames=None):
        # Fall back to the configured frame count, else a single frame.
        count = (num_frames or self.frames) or 1
        frame_mask = self.sample_mask_per_frame()
        return np.tile(frame_mask, (count, 1)).flatten()
class RotatedTableMaskingGenerator(TubeMaskingGenerator):
    """Tube masking with a fully visible "table top" of leading frames.

    The first `num_frames - tube_length` frames are all zeros (visible);
    the remaining frames share one tube mask.
    """

    def __init__(self, tube_length=None, *args, **kwargs):
        super(RotatedTableMaskingGenerator, self).__init__(*args, **kwargs)
        self.tube_length = tube_length

    def __call__(self, num_frames=None):
        count = (num_frames or self.frames) or 2
        length = self.tube_length or (count - 1)
        thickness = count - length
        # At least one leading frame must remain fully visible.
        assert length < count, (length, count)
        tube_part = super().__call__(num_frames=length)
        visible_part = np.zeros(thickness * self.height * self.width).astype(tube_part.dtype).flatten()
        return np.concatenate([visible_part, tube_part], 0)
class PytorchMaskGeneratorWrapper(nn.Module):
    """Pytorch wrapper for numpy masking generators"""

    def __init__(self,
                 mask_generator=TubeMaskingGenerator,
                 *args, **kwargs):
        super().__init__()
        # Instantiate the numpy-based generator with the remaining args.
        self.mask_generator = mask_generator(*args, **kwargs)

    # Bug fix: mask_ratio appeared twice as plain same-name methods (the
    # second definition shadowed the first); restored the @property / setter
    # pair delegating to the wrapped generator.
    @property
    def mask_ratio(self):
        return self.mask_generator.mask_ratio

    @mask_ratio.setter
    def mask_ratio(self, value):
        self.mask_generator.mask_ratio = value

    def forward(self, device='cuda', dtype_out=torch.bool, **kwargs):
        """Sample a mask from the wrapped generator as a tensor on `device`.

        kwargs are forwarded to the generator's __call__ (e.g. num_frames).
        """
        masks = self.mask_generator(**kwargs)
        masks = torch.tensor(masks).to(device).to(dtype_out)
        return masks
class MaskingGenerator(nn.Module):
    """Pytorch base class for masking generators.

    Samples boolean patch masks (True = masked) over a (frames, height,
    width) patch grid, using the global torch RNG for shuffling and a private
    numpy RNG for padding offsets.
    """

    def __init__(self,
                 input_size,
                 mask_ratio,
                 seed=0,
                 visible_frames=0,
                 clumping_factor=1,
                 randomize_num_visible=False,
                 create_on_cpu=True,
                 always_batch=False):
        super().__init__()
        self.frames = None
        if len(input_size) == 3:
            self.frames, self.height, self.width = input_size
        elif len(input_size) == 2:
            self.height, self.width = input_size
        elif len(input_size) == 1 or isinstance(input_size, int):
            # NOTE(review): a bare int would hit len() first and raise
            # TypeError before reaching this branch — pass a 1-tuple instead.
            self.height = self.width = input_size

        self.clumping_factor = clumping_factor
        # Padding needed when the clump size does not evenly divide the grid.
        self.pad_h = self.height % self.c[0]
        self.pad_w = self.width % self.c[1]
        self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
        self.mask_ratio = mask_ratio  # routed through the property setter below

        self.visible_frames = visible_frames
        self.always_batch = always_batch
        self.create_on_cpu = create_on_cpu
        self.rng = np.random.RandomState(seed=seed)
        self._set_torch_seed(seed)
        self.randomize_num_visible = randomize_num_visible

    # Bug fix: the accessors below had lost their decorators and were plain
    # methods, so e.g. `self.c[0]` in __init__ indexed a bound method and
    # __repr__ formatted bound methods. Restored @property / setter pairs.
    @property
    def num_masks_per_frame(self):
        # Lazily derived from mask_ratio on first access.
        if not hasattr(self, '_num_masks_per_frame'):
            self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
        return self._num_masks_per_frame

    @num_masks_per_frame.setter
    def num_masks_per_frame(self, val):
        self._num_masks_per_frame = val
        self._mask_ratio = (val / self.num_patches_per_frame)

    @property
    def c(self):
        # Clumping factor normalized to an (h, w) pair.
        if isinstance(self.clumping_factor, int):
            return (self.clumping_factor,) * 2
        else:
            return self.clumping_factor[:2]

    @property
    def mask_ratio(self):
        return self._mask_ratio

    @mask_ratio.setter
    def mask_ratio(self, val):
        self._mask_ratio = val
        self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)

    @property
    def num_visible(self):
        return self.num_patches_per_frame - self.num_masks_per_frame

    @num_visible.setter
    def num_visible(self, val):
        self.num_masks_per_frame = self.num_patches_per_frame - val

    def _set_torch_seed(self, seed):
        # NOTE(review): seeds the *global* torch RNG, so constructing this
        # module affects randomness elsewhere in the process.
        self.seed = seed
        torch.manual_seed(self.seed)

    def __repr__(self):
        repr_str = ("Class: {}\nMask: total patches per mask {},\n" + \
                    "mask patches per mask {}, visible patches per mask {}, mask ratio {:0.3f}\n" + \
                    "randomize num visible? {}").format(
            type(self).__name__, self.num_patches_per_frame,
            self.num_masks_per_frame, self.num_visible, self.mask_ratio,
            self.randomize_num_visible
        )
        return repr_str

    def sample_mask_per_frame(self, *args, **kwargs):
        """Sample one flat (height*width,) bool mask; True = masked."""
        num_masks = self.num_masks_per_frame
        if self.randomize_num_visible:
            # Randomize upward: at least the configured number of masks.
            num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
        mask = torch.cat([
            torch.zeros([self.num_patches_per_frame - num_masks]),
            torch.ones([num_masks])], 0).bool()
        inds = torch.randperm(mask.size(0)).long()
        mask = mask[inds]
        if max(*self.c) > 1:
            # Expand each coarse decision into a c[0] x c[1] block, then pad
            # (with masked values) at a random offset to reach (height, width).
            mask = mask.view(self.height // self.c[0],
                             1,
                             self.width // self.c[1],
                             1)
            mask = torch.tile(mask, (1, self.c[0], 1, self.c[1]))
            mask = mask.reshape(self.height - self.pad_h, self.width - self.pad_w)
            _pad_h = self.rng.choice(range(self.pad_h + 1))
            pad_h = (self.pad_h - _pad_h, _pad_h)
            _pad_w = self.rng.choice(range(self.pad_w + 1))
            pad_w = (self.pad_w - _pad_w, _pad_w)
            # F.pad takes the last dim's padding first, hence pad_w + pad_h.
            mask = F.pad(mask,
                         pad_w + pad_h,
                         mode='constant',
                         value=1)
            mask = mask.reshape(self.height, self.width)
        return mask

    def forward(self, x=None, num_frames=None):
        """Sample flattened masks: (B, T*H*W) if `x` is a batched tensor,
        else (T*H*W,) — subject to `always_batch`/`visible_frames` below."""
        num_frames = (num_frames or self.frames) or 1
        if isinstance(x, torch.Tensor):
            # One independent mask per batch element.
            batch_size = x.size(0)
            masks = torch.stack([
                torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
                for b in range(batch_size)], 0)
            if not self.create_on_cpu:
                masks = masks.to(x.device)
            if batch_size == 1 and not self.always_batch:
                masks = masks.squeeze(0)
        else:
            batch_size = 1
            masks = torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
            if self.always_batch:
                masks = masks[None]

        # Optionally prepend fully visible (all-False) frames.
        if self.visible_frames > 0:
            vis = torch.zeros((batch_size, 1, self.height, self.width), dtype=torch.bool)
            vis = vis.view(masks.shape).to(masks.device)
            masks = torch.cat(([vis] * self.visible_frames) + [masks], -1)

        return masks