# NOTE(review): the following lines were web-page scrape residue
# ("Spaces: Running on CPU Upgrade"), not source code; kept here as a
# comment so the file stays syntactically valid.
import torch
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from torchvision import transforms

from cwm.data.dataset import ContextAndTargetVideoDataset
from cwm.data.masking_generator import RotatedTableMaskingGenerator
from cwm.data.transforms import *
class DataAugmentationForVideoMAE(object):
    """Frame-group transform pipeline for VideoMAE-style pretraining.

    Builds a ``transforms.Compose`` that scales each group of frames to
    ``input_size``, applies a train-time crop (multi-scale random crop or
    center crop), stacks the frames into a tensor, and normalizes with
    ImageNet mean/std.
    """

    def __init__(self, augmentation_type, input_size, augmentation_scales):
        """Assemble the transform pipeline.

        Args:
            augmentation_type: either ``'multiscale'`` (random multi-scale
                crop) or ``'center'`` (deterministic center crop).
            input_size: target spatial size of the output frames.
            augmentation_scales: iterable of crop scales used when
                ``augmentation_type == 'multiscale'``.

        Raises:
            ValueError: if ``augmentation_type`` is not one of the two
                supported values.
        """
        transform_list = []
        self.scale = GroupScale(input_size)
        transform_list.append(self.scale)
        if augmentation_type == 'multiscale':
            self.train_augmentation = GroupMultiScaleCrop(input_size, list(augmentation_scales))
        elif augmentation_type == 'center':
            self.train_augmentation = GroupCenterCrop(input_size)
        else:
            # Fail fast: previously an unknown type left
            # ``self.train_augmentation`` unset and crashed with a
            # confusing AttributeError on the extend() below.
            raise ValueError(
                "augmentation_type must be 'multiscale' or 'center', got %r"
                % (augmentation_type,)
            )
        transform_list.extend([self.train_augmentation, Stack(roll=False), ToTorchFormatTensor(div=True)])
        # Normalize input images
        normalize = GroupNormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        transform_list.append(normalize)
        self.transform = transforms.Compose(transform_list)

    def __call__(self, images):
        # The composed transform returns (data, extra); only the processed
        # frames are of interest here.
        process_data, _ = self.transform(images)
        return process_data

    def __repr__(self):
        # Local renamed from ``repr`` to avoid shadowing the builtin.
        text = "(DataAugmentationForVideoMAE,\n"
        text += "  transform = %s,\n" % str(self.transform)
        text += ")"
        return text
def build_pretraining_dataset(args):
    """Construct the pretraining dataset described by ``args``.

    One ``ContextAndTargetVideoDataset`` is built per data path (either the
    single ``args.data_path`` or every entry of ``args.data_path_list``),
    all sharing one augmentation pipeline and one mask generator; the
    per-path datasets are concatenated into a single dataset.

    Args:
        args: parsed configuration namespace; the attributes read here are
            the augmentation, masking, and video-sampling settings.

    Returns:
        A ``torch.utils.data.ConcatDataset`` over all configured paths.
    """
    data_transform = DataAugmentationForVideoMAE(
        args.augmentation_type, args.input_size, args.augmentation_scales
    )
    mask_generator = RotatedTableMaskingGenerator(
        input_size=args.mask_input_size,
        mask_ratio=args.mask_ratio,
        tube_length=args.tubelet_size,
        batch_size=args.batch_size,
        mask_type=args.mask_type,
    )
    # A lone data_path behaves like a one-element path list.
    data_paths = [args.data_path] if args.data_path_list is None else args.data_path_list
    dataset_list = [
        ContextAndTargetVideoDataset(
            root=None,
            setting=data_path,
            video_ext='mp4',
            is_color=True,
            modality='rgb',
            context_length=args.context_frames,
            target_length=args.target_frames,
            step_units=args.temporal_units,
            new_step=args.sampling_rate,
            context_target_gap=args.context_target_gap,
            transform=data_transform,
            randomize_interframes=False,
            channels_first=True,
            temporal_jitter=False,
            train=True,
            mask_generator=mask_generator,
        )
        for data_path in data_paths
    ]
    return torch.utils.data.ConcatDataset(dataset_list)