File size: 3,021 Bytes
6dfcb0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import numpy as np
import torch

def get_tubes(masks_per_frame, tube_length):
    """Stack successive random shufflings of a per-frame mask vector.

    The input is shuffled once to give the first row; every following row
    is a fresh shuffle of the row before it. All shuffles are drawn with
    ``torch.randperm``, i.e. from torch's global random generator.

    Args:
        masks_per_frame: Tensor indexed along its first dimension
            (the visible caller passes a 1-D 0/1 vector). It is not
            modified; indexing with a permutation produces a new tensor.
        tube_length: Requested number of rows. Values below 1 still
            yield a single row, because the first shuffle is unconditional.

    Returns:
        ``torch.vstack`` of the shuffled rows, first shuffle on top.
    """
    current = masks_per_frame[torch.randperm(len(masks_per_frame))]
    rows = [current]

    remaining = tube_length - 1
    while remaining > 0:
        # Re-shuffle the previous row so the values (and their counts) are
        # preserved while their positions change from row to row.
        current = current[torch.randperm(len(current))]
        rows.append(current)
        remaining -= 1

    return torch.vstack(rows)

class RotatedTableMaskingGenerator:
    """Build batches of flat patch masks where only the last frames are masked.

    For every sample the first ``frames - tube_length`` frames are all
    zeros, and each of the last ``tube_length`` frames contains
    ``num_masks`` ones placed at random patch positions (an independent
    shuffle per frame, produced by ``get_tubes``). Ones appear to denote
    masked patches, going by the ``num_masks`` naming.

    Args:
        input_size: ``(frames, height, width)`` measured in patches.
        mask_ratio: Fraction of patches set to one in each of the last
            ``tube_length`` frames. Re-sampled (and overwritten on the
            instance) at every call for the ``'rotated_table_magvit'``
            and ``'rotated_table_maskvit'`` mask types.
        tube_length: Number of trailing frames that receive ones. Must not
            exceed ``frames``; otherwise building the leading zero frames
            fails at call time.
        batch_size: Number of masks returned per call.
        mask_type: ``'rotated_table'`` keeps ``mask_ratio`` fixed;
            ``'rotated_table_magvit'`` draws ``cos(u * pi / 2)`` with
            ``u ~ U[0, 1)``; ``'rotated_table_maskvit'`` draws from
            ``U[0.5, 1)``.
        seed: Seed for the private NumPy generator used when
            ``randomize_num_visible`` is set. It does not seed the torch
            shuffles or the ``np.random`` ratio sampling, which both use
            their global generators.
        randomize_num_visible: If true, the number of ones per frame is
            drawn per sample, uniformly between the ratio-derived count and
            the number of patches per frame (both inclusive).
    """

    def __init__(self,
                 input_size,
                 mask_ratio,
                 tube_length,
                 batch_size,
                 mask_type='rotated_table',
                 seed=None,
                 randomize_num_visible=False):

        self.batch_size = batch_size

        self.mask_ratio = mask_ratio
        self.tube_length = tube_length

        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame

        self.seed = seed
        # __call__ reads self.rng when randomize_num_visible is set; it was
        # never created before, so that option raised AttributeError.
        self.rng = np.random.RandomState(seed)
        self.randomize_num_visible = randomize_num_visible

        self.mask_type = mask_type

    def __repr__(self):
        """Return a one-line summary of the generator's configuration."""
        # NOTE(review): the label says "Inverted Table" although the class is
        # named "RotatedTable"; the text is kept as-is in case logs rely on it.
        repr_str = "Inverted Table Mask: total patches {}, tube length {}, randomize num visible? {}, seed {}".format(
            self.total_patches, self.tube_length, self.randomize_num_visible, self.seed
        )
        return repr_str

    def __call__(self, m=None):
        """Generate one batch of masks.

        Args:
            m: Unused; accepted for call-signature compatibility.

        Returns:
            Tensor of shape ``(batch_size, frames * height * width)`` holding
            zeros and ones in torch's default floating dtype.
        """
        # Schedule-based mask types draw a new ratio per call (shared by the
        # whole batch) and store it on the instance.
        if self.mask_type == 'rotated_table_magvit':
            self.mask_ratio = np.cos(np.random.uniform(low=0.0, high=1) * np.pi / 2)
        elif self.mask_type == 'rotated_table_maskvit':
            self.mask_ratio = np.random.uniform(low=0.5, high=1)

        patches = self.num_patches_per_frame

        # These depend only on the ratio, so compute them once per call
        # rather than once per sample.
        self.num_masks_per_frame = max(0, int(self.mask_ratio * patches))
        self.total_masks = self.tube_length * self.num_masks_per_frame

        # All-zero block for the frames that precede the masked tube.
        leading = torch.zeros((self.frames - self.tube_length, patches))

        all_masks = []
        for _ in range(self.batch_size):
            num_masks = self.num_masks_per_frame
            if self.randomize_num_visible:
                # Upper bound is exclusive, hence the +1 to allow a fully
                # masked frame.
                num_masks = int(self.rng.randint(low=num_masks, high=patches + 1))

            # A single construction covers every count, including zero
            # (torch.ones(0) is empty), and always yields exactly `patches`
            # entries so the frame lines up with the leading zero block.
            mask_per_frame = torch.hstack([
                torch.zeros(patches - num_masks),
                torch.ones(num_masks),
            ])

            tubes = get_tubes(mask_per_frame, self.tube_length)
            mask = torch.cat([leading.to(tubes.dtype), tubes])
            all_masks.append(mask.flatten())
        return torch.stack(all_masks)