import torch
from torchvision import transforms

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from cwm.data.dataset import ContextAndTargetVideoDataset
from cwm.data.masking_generator import RotatedTableMaskingGenerator
from cwm.data.transforms import (
    GroupScale,
    GroupMultiScaleCrop,
    GroupCenterCrop,
    Stack,
    ToTorchFormatTensor,
    GroupNormalize,
)

class DataAugmentationForVideoMAE(object):
    """Frame-group transform pipeline: resize, crop, stack, and normalize."""

    def __init__(self, augmentation_type, input_size, augmentation_scales):

        transform_list = []

        # Resize every frame in the group to the target input size
        self.scale = GroupScale(input_size)
        transform_list.append(self.scale)

        # Spatial crop: random multi-scale crops for augmentation, or a fixed center crop
        if augmentation_type == 'multiscale':
            self.train_augmentation = GroupMultiScaleCrop(input_size, list(augmentation_scales))
        elif augmentation_type == 'center':
            self.train_augmentation = GroupCenterCrop(input_size)
        else:
            raise ValueError("Unsupported augmentation_type: %s" % augmentation_type)

        # Stack the frame group and convert it to a float tensor scaled to [0, 1]
        transform_list.extend([self.train_augmentation, Stack(roll=False), ToTorchFormatTensor(div=True)])

        # Normalize input images
        normalize = GroupNormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        transform_list.append(normalize)

        self.transform = transforms.Compose(transform_list)

    def __call__(self, images):
        # The composed transform returns a (data, extra) pair; only the
        # processed clip tensor is needed here
        process_data, _ = self.transform(images)
        return process_data

    def __repr__(self):
        repr_str = "(DataAugmentationForVideoMAE,\n"
        repr_str += "  transform = %s,\n" % str(self.transform)
        repr_str += ")"
        return repr_str

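# A minimal usage sketch (illustrative assumption, not from the repo):
# `frames` is assumed to be a list of PIL images for one clip, and the
# composed transform to return a (tensor, extra) pair as unpacked above.
#
#   aug = DataAugmentationForVideoMAE(
#       augmentation_type='multiscale',
#       input_size=224,
#       augmentation_scales=(1.0, 0.875, 0.75),
#   )
#   clip = aug(frames)  # stacked, ImageNet-normalized clip tensor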

def build_pretraining_dataset(args):
    """Build the pretraining dataset, concatenating one dataset per data path."""

    dataset_list = []
    data_transform = DataAugmentationForVideoMAE(args.augmentation_type, args.input_size, args.augmentation_scales)

    # Generates the per-sample token masks used for masked video modeling
    mask_generator = RotatedTableMaskingGenerator(
        input_size=args.mask_input_size,
        mask_ratio=args.mask_ratio,
        tube_length=args.tubelet_size,
        batch_size=args.batch_size,
        mask_type=args.mask_type,
    )

    # Fall back to the single data_path when no list of paths is provided
    data_paths = [args.data_path] if args.data_path_list is None else args.data_path_list

    for data_path in data_paths:
        dataset = ContextAndTargetVideoDataset(
            root=None,
            setting=data_path,
            video_ext='mp4',
            is_color=True,
            modality='rgb',
            context_length=args.context_frames,
            target_length=args.target_frames,
            step_units=args.temporal_units,
            new_step=args.sampling_rate,
            context_target_gap=args.context_target_gap,
            transform=data_transform,
            randomize_interframes=False,
            channels_first=True,
            temporal_jitter=False,
            train=True,
            mask_generator=mask_generator,
        )
        dataset_list.append(dataset)

    return torch.utils.data.ConcatDataset(dataset_list)
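

# A minimal smoke-test sketch (hypothetical values, not repo defaults): the
# attribute names mirror the `args` fields read above, but every value here
# is an assumption, and `data_path` must point at a real annotation file for
# the dataset to load.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        augmentation_type='multiscale',
        input_size=224,
        augmentation_scales=(1.0, 0.875, 0.75),
        mask_input_size=(8, 14, 14),   # assumed (T, H, W) token grid
        mask_ratio=0.9,
        tubelet_size=2,
        batch_size=8,
        mask_type='rotated_table',     # assumed mask-type identifier
        data_path='/path/to/annotations.txt',
        data_path_list=None,
        context_frames=2,
        target_frames=1,
        temporal_units='frames',       # assumed unit string
        sampling_rate=1,
        context_target_gap=None,       # assumed; gap between context and target
    )
    dataset = build_pretraining_dataset(args)
    print("num samples:", len(dataset))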