import numpy as np
import torch


def get_tubes(masks_per_frame, tube_length):
    """Stack ``tube_length`` successive random shuffles of a per-frame mask.

    The first row is a random permutation of ``masks_per_frame``; every
    following row is a fresh random permutation of the row before it, so
    each row holds the same multiset of values as the input.

    Args:
        masks_per_frame: 1-D tensor with one entry per patch of a frame.
        tube_length: Number of shuffles requested. At least one row is
            always produced, even when ``tube_length`` is less than 1.

    Returns:
        A 2-D tensor with one shuffled copy of ``masks_per_frame`` per row.
    """
    # Indexing with a permutation tensor already yields a new tensor, so no
    # explicit clone is needed between rounds.
    current = masks_per_frame[torch.randperm(len(masks_per_frame))]
    tubes = [current]
    for _ in range(tube_length - 1):
        current = current[torch.randperm(len(current))]
        tubes.append(current)
    return torch.vstack(tubes)


class RotatedTableMaskingGenerator:
    """Generate batches of flattened per-patch masks for a clip.

    The first ``frames - tube_length`` frames of every mask are all zeros;
    the last ``tube_length`` frames each carry the same number of ones at
    independently shuffled patch positions.
    (Presumably 1 marks a masked patch and 0 a visible one; confirm
    against the consumer of these masks.)

    Args:
        input_size: ``(frames, height, width)`` measured in patches.
        mask_ratio: Fraction of the patches of a frame that are set to one.
            Re-sampled on every call for the ``'rotated_table_magvit'`` and
            ``'rotated_table_maskvit'`` mask types.
        tube_length: Number of trailing frames that receive ones.
        batch_size: Number of masks produced per call.
        mask_type: ``'rotated_table'`` (fixed ratio),
            ``'rotated_table_magvit'`` or ``'rotated_table_maskvit'``.
            Any other value behaves like the fixed-ratio case.
        seed: Stored and reported by ``__repr__``; not otherwise used here.
        randomize_num_visible: Not implemented; calling the generator with
            this flag set raises ``NotImplementedError``.
    """

    def __init__(self, input_size, mask_ratio, tube_length, batch_size,
                 mask_type='rotated_table', seed=None,
                 randomize_num_visible=False):
        self.batch_size = batch_size
        self.mask_ratio = mask_ratio
        self.tube_length = tube_length
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.seed = seed
        self.randomize_num_visible = randomize_num_visible
        self.mask_type = mask_type

    def __repr__(self):
        # NOTE(review): the label reads "Inverted Table Mask" although the
        # class is RotatedTableMaskingGenerator; kept verbatim in case logs
        # are matched against it.
        repr_str = (
            "Inverted Table Mask: total patches {}, tube length {}, "
            "randomize num visible? {}, seed {}"
        ).format(
            self.total_patches,
            self.tube_length,
            self.randomize_num_visible,
            self.seed,
        )
        return repr_str

    def __call__(self, m=None):
        """Return a ``(batch_size, frames * height * width)`` mask tensor.

        Args:
            m: Unused; accepted for call-signature compatibility.

        Returns:
            A floating-point tensor with one flattened mask per batch
            element.

        Raises:
            NotImplementedError: If ``randomize_num_visible`` is set.
        """
        # The sampled ratio is written back to ``self.mask_ratio`` (as the
        # original code did) so callers can inspect the ratio last used.
        if self.mask_type == 'rotated_table_magvit':
            sample = np.random.uniform(low=0.0, high=1)
            self.mask_ratio = np.cos(sample * np.pi / 2)
        elif self.mask_type == 'rotated_table_maskvit':
            self.mask_ratio = np.random.uniform(low=0.5, high=1)

        if self.randomize_num_visible:
            # Previously an always-true ``assert "<message>"`` followed by a
            # read of the undefined ``self.rng``; fail explicitly instead.
            raise NotImplementedError("Randomize num visible Not implemented")

        # Everything below depends only on the (now fixed) ratio, so it is
        # computed once rather than once per batch element.
        patches = self.num_patches_per_frame
        self.num_masks_per_frame = max(0, int(self.mask_ratio * patches))
        self.total_masks = self.tube_length * self.num_masks_per_frame
        num_masks = self.num_masks_per_frame

        # Concatenating with an empty ones tensor is a no-op, so a zero
        # ratio needs no special case.
        mask_per_frame = torch.cat([
            torch.zeros(patches - num_masks),
            torch.ones(num_masks),
        ])
        # All-zero rows for the leading frames outside the tube.
        top = torch.zeros(
            self.frames - self.tube_length,
            patches,
            dtype=mask_per_frame.dtype,
        )

        all_masks = []
        for _ in range(self.batch_size):
            tubes = get_tubes(mask_per_frame, self.tube_length)
            all_masks.append(torch.cat([top, tubes]).flatten())
        return torch.stack(all_masks)