rahulvenkk
app.py updated
6dfcb0f
raw
history blame
2.65 kB
from torchvision import transforms
from cwm.data.transforms import *
from cwm.data.dataset import ContextAndTargetVideoDataset
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from cwm.data.masking_generator import RotatedTableMaskingGenerator
class DataAugmentationForVideoMAE(object):
def __init__(self, augmentation_type, input_size, augmentation_scales):
transform_list = []
self.scale = GroupScale(input_size)
transform_list.append(self.scale)
if augmentation_type == 'multiscale':
self.train_augmentation = GroupMultiScaleCrop(input_size, list(augmentation_scales))
elif augmentation_type == 'center':
self.train_augmentation = GroupCenterCrop(input_size)
transform_list.extend([self.train_augmentation, Stack(roll=False), ToTorchFormatTensor(div=True)])
# Normalize input images
normalize = GroupNormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
transform_list.append(normalize)
self.transform = transforms.Compose(transform_list)
def __call__(self, images):
process_data, _ = self.transform(images)
return process_data
def __repr__(self):
repr = "(DataAugmentationForVideoMAE,\n"
repr += " transform = %s,\n" % str(self.transform)
repr += ")"
return repr
def build_pretraining_dataset(args):
dataset_list = []
data_transform = DataAugmentationForVideoMAE(args.augmentation_type, args.input_size, args.augmentation_scales)
mask_generator = RotatedTableMaskingGenerator(
input_size=args.mask_input_size,
mask_ratio=args.mask_ratio,
tube_length=args.tubelet_size,
batch_size=args.batch_size,
mask_type=args.mask_type
)
for data_path in [args.data_path] if args.data_path_list is None else args.data_path_list:
dataset = ContextAndTargetVideoDataset(
root=None,
setting=data_path,
video_ext='mp4',
is_color=True,
modality='rgb',
context_length=args.context_frames,
target_length=args.target_frames,
step_units=args.temporal_units,
new_step=args.sampling_rate,
context_target_gap=args.context_target_gap,
transform=data_transform,
randomize_interframes=False,
channels_first=True,
temporal_jitter=False,
train=True,
mask_generator=mask_generator,
)
dataset_list.append(dataset)
dataset = torch.utils.data.ConcatDataset(dataset_list)
return dataset