import os

import decord
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class VideoMAE(torch.utils.data.Dataset):
    """Load your own video classification dataset.

    Parameters
    ----------
    root : str, required.
        Path to the root folder storing the dataset.
    setting : str, required.
        A text file describing the dataset, one line per video sample.
        Each line has two items: (1) video path and (2) video label.
    train : bool, default True.
        Whether to load the training or validation set.
    test_mode : bool, default False.
        Whether to perform evaluation on the test set.
        Usually a three-crop or ten-crop evaluation strategy is involved.
    name_pattern : str, default 'img_%05d.jpg'.
        The naming pattern of the decoded video frames, e.g., img_00012.jpg.
    video_ext : str, default 'mp4'.
        If video_loader is set to True, please specify the video format accordingly.
    is_color : bool, default True.
        Whether the loaded image is color or grayscale.
    modality : str, default 'rgb'.
        Input modality; only RGB video frames are supported for now.
        Support for RGB difference and optical flow images may be added later.
    num_segments : int, default 1.
        Number of segments to evenly divide the video into clips.
        A useful technique to obtain global video-level information.
        Limin Wang, et al., Temporal Segment Networks: Towards Good Practices
        for Deep Action Recognition, ECCV 2016.
    num_crop : int, default 1.
        Number of crops for each image. Common choices are three crops and
        ten crops during evaluation.
    new_length : int, default 1.
        The length of the input video clip. Default is a single image, but it can
        be multiple video frames. For example, new_length=16 means we will extract
        a video clip of 16 consecutive frames.
    new_step : int, default 1.
        Temporal sampling rate. For example, new_step=1 means we will extract a
        video clip of consecutive frames; new_step=2 means we will extract a video
        clip of every other frame.
    randomize_interframes : bool, default False.
        If True, new_step is treated as the maximum step and the actual step is
        sampled per clip.
    temporal_jitter : bool, default False.
        Whether to temporally jitter if new_step > 1.
    video_loader : bool, default False.
        Whether to use a video loader to load data.
    use_decord : bool, default False.
        Whether to use the Decord video loader to load data. Otherwise use the
        mmcv video loader.
    transform : function, default None.
        A function that takes data and label and transforms them.
    lazy_init : bool, default False.
        If set to True, build a dataset instance without loading any dataset.
    is_video_dataset : bool, default True.
        If False, do not parse the setting file to build the clip list.
    """
""" def __init__(self, root, setting, train=True, test_mode=False, name_pattern='img_%05d.jpg', video_ext='mp4', is_color=True, modality='rgb', num_segments=1, num_crop=1, new_length=1, new_step=1, randomize_interframes=False, transform=None, temporal_jitter=False, video_loader=False, use_decord=False, lazy_init=False, is_video_dataset=True): super(VideoMAE, self).__init__() self.root = root self.setting = setting self.train = train self.test_mode = test_mode self.is_color = is_color self.modality = modality self.num_segments = num_segments self.num_crop = num_crop self.new_length = new_length self.randomize_interframes = randomize_interframes self._new_step = new_step # If randomize_interframes is True, then this is the max, otherwise it's just the skip # self._skip_length = self.new_length * self.new_step # If randomize_interframes is True, then this isn't used, otherwise it's used as calculated self.temporal_jitter = temporal_jitter self.name_pattern = name_pattern self.video_loader = video_loader self.video_ext = video_ext self.use_decord = use_decord self.transform = transform self.lazy_init = lazy_init if (not self.lazy_init) and is_video_dataset: self.clips = self._make_dataset(root, setting) if len(self.clips) == 0: raise (RuntimeError("Found 0 video clips in subfolders of: " + root + "\n" "Check your data directory (opt.data-dir).")) def __getitem__(self, index): directory, target = self.clips[index] if self.video_loader: if '.' in directory.split('/')[-1]: # data in the "setting" file already have extension, e.g., demo.mp4 video_name = directory else: # data in the "setting" file do not have extension, e.g., demo # So we need to provide extension (i.e., .mp4) to complete the file name. video_name = '{}.{}'.format(directory, self.video_ext) try: decord_vr = decord.VideoReader(video_name, num_threads=1) except: # return video_name return (self.__getitem__(index + 1)) duration = len(decord_vr) segment_indices, skip_offsets, new_step, skip_length = self._sample_train_indices(duration) images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets, new_step, skip_length) process_data, mask = self.transform((images, None)) # T*C,H,W process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1) # T*C,H,W -> T,C,H,W -> C,T,H,W return (process_data, mask) def __len__(self): return len(self.clips) def _make_dataset(self, directory, setting): if not os.path.exists(setting): raise (RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))) clips = [] with open(setting) as split_f: data = split_f.readlines() for line in data: line_info = line.split(' ') # line format: video_path, video_duration, video_label if len(line_info) < 2: raise (RuntimeError('Video input format is not correct, missing one or more element. %s' % line)) elif len(line_info) > 2: line_info = (' '.join(line_info[:-1]), line_info[-1]) # filename has spaces clip_path = os.path.join(line_info[0]) target = int(line_info[1]) item = (clip_path, target) clips.append(item) # import torch_xla.core.xla_model as xm # print = xm.master_print # print("Dataset created. 
Number of clips: ", len(clips)) return clips def _sample_train_indices(self, num_frames): if self.randomize_interframes is False: new_step = self._new_step else: new_step = np.random.randint(1, self._new_step + 1) skip_length = self.new_length * new_step average_duration = (num_frames - skip_length + 1) // self.num_segments if average_duration > 0: offsets = np.multiply(list(range(self.num_segments)), average_duration) offsets = offsets + np.random.randint(average_duration, size=self.num_segments) elif num_frames > max(self.num_segments, skip_length): offsets = np.sort(np.random.randint( num_frames - skip_length + 1, size=self.num_segments)) else: offsets = np.zeros((self.num_segments,)) if self.temporal_jitter: skip_offsets = np.random.randint( new_step, size=skip_length // new_step) else: skip_offsets = np.zeros( skip_length // new_step, dtype=int) return offsets + 1, skip_offsets, new_step, skip_length def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets, new_step, skip_length): sampled_list = [] frame_id_list = [] for seg_ind in indices: offset = int(seg_ind) for i, _ in enumerate(range(0, skip_length, new_step)): if offset + skip_offsets[i] <= duration: frame_id = offset + skip_offsets[i] - 1 else: frame_id = offset - 1 frame_id_list.append(frame_id) if offset + new_step < duration: offset += new_step try: video_data = video_reader.get_batch(frame_id_list).asnumpy() sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in enumerate(frame_id_list)] except: raise RuntimeError( 'Error occured in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory, duration)) return sampled_list class ContextAndTargetVideoDataset(VideoMAE): """ A video dataset whose provided videos consist of (1) a "context" sequence of length Tc and (2) a "target" sequence Tt. These two sequences have the same frame rate (specificiable in real units) but are separated by a specified gap (which may vary for different examples.) The main use case is for training models to predict ahead by some variable amount, given the context. 
""" standard_fps = [12, 24, 30, 48, 60, 100] def __init__(self, root, setting, train=True, test_mode=False, transform=None, step_units='ms', new_step=150, start_frame=0, context_length=2, target_length=1, channels_first=True, generate_masks=True, mask_generator=None, context_target_gap=[400, 600], normalize_timestamps=True, default_fps=30, min_fps=0.1, seed=0, *args, **kwargs): super(ContextAndTargetVideoDataset, self).__init__( root=root, setting=setting, train=train, test_mode=test_mode, transform=transform, new_length=context_length, use_decord=True, lazy_init=False, video_loader=True, *args, **kwargs) # breakpoint() self.context_length = self.new_length self.target_length = target_length ## convert from fps and step size to frames self._fps = None self._min_fps = min_fps self._default_fps = default_fps self._step_units = step_units self.new_step = new_step ## sampling for train and test self._start_frame = start_frame self.gap = context_target_gap self.seed = seed self.rng = np.random.RandomState(seed=seed) # breakpoint() ## output formatting self._channels_first = channels_first self._normalize_timestamps = normalize_timestamps self._generate_masks = generate_masks self.mask_generator = mask_generator def _get_frames_per_t(self, t): if self._step_units == 'frames' or (self._step_units is None): return int(t) assert self._fps is not None t_per_frame = 1 / self._fps if self._step_units in ['ms', 'milliseconds']: t_per_frame *= 1000.0 return max(int(np.round(t / t_per_frame)), 1) @property def new_step(self): if self._fps is None: return None else: return self._get_frames_per_t(self._new_step) @new_step.setter def new_step(self, v): self._new_step = v @property def gap(self): if self._fps is None: return [1, 2] else: gap = [self._get_frames_per_t(self._gap[0]), self._get_frames_per_t(self._gap[1])] gap[1] = max(gap[1], gap[0] + 1) return gap @gap.setter def gap(self, v): if v is None: v = self._new_step if not isinstance(v, (list, tuple)): v = [v, v] self._gap = v def _get_video_name(self, directory): if ''.join(['.', self.video_ext]) in directory.split('/')[-1]: # data in the "setting" file has extension, e.g. demo.mpr video_name = directory else: # data doesn't have an extension video_name = '{}.{}'.format(directory, self.video_ext) return video_name def _set_fps(self, reader): """click fps to a standard""" if self._step_units == 'frames' or self._step_units is None: self._fps = None else: self._fps = None fps = reader.get_avg_fps() for st in self.standard_fps: if (int(np.floor(fps)) == st) or (int(np.ceil(fps)) == st): self._fps = st if self._fps is None: self._fps = int(np.round(fps)) if self._fps < self._min_fps: self._fps = self._default_fps def _get_step_and_gap(self): step = self.new_step if self.randomize_interframes and self.train: step = self.rng.randint(1, step + 1) if self.train: gap = self.rng.randint(*self.gap) else: gap = sum(self.gap) // 2 return (step, gap) def _sample_frames(self): step, gap = self._get_step_and_gap() ## compute total length of sample ## e.g. 
    def _sample_frames(self):
        step, gap = self._get_step_and_gap()

        ## compute total length of sample
        ## e.g., if context_length = 2, step = 1, gap = 10, target_length = 2:
        ##     total_length = 2 * 1 + 10 + (2 - 1) * 1 = 13
        ##     so len(video) must be >= 13
        self._total_length = self.context_length * step + gap + (self.target_length - 1) * step
        if self._total_length > (self._num_frames - self._start_frame):
            if self.train:
                return None
            else:
                raise ValueError(
                    "sample of length %d starting at frame %d is too long for a video of %d frames" %
                    (self._total_length, self._start_frame, self._num_frames))

        ## sample the frames randomly (if training) or from the start frame (if testing)
        if self.train:
            self.start_frame_now = self.rng.randint(
                min(self._start_frame, self._num_frames - self._total_length),
                self._num_frames - self._total_length + 1)
        else:
            self.start_frame_now = min(self._start_frame, self._num_frames - self._total_length)

        frames = [self.start_frame_now + i * step for i in range(self.context_length)]
        frames += [frames[-1] + gap + i * step for i in range(self.target_length)]
        return frames

    def _decode_frame_images(self, reader, frames):
        try:
            video_data = reader.get_batch(frames).asnumpy()
            video_data = [Image.fromarray(video_data[t, :, :, :]).convert('RGB')
                          for t, _ in enumerate(frames)]
        except Exception:
            raise RuntimeError(
                "Error occurred in reading frames {} from video {} of duration {}".format(
                    frames, self.index, self._num_frames))
        return video_data

    def __getitem__(self, index):
        self.index = index
        self.directory, target = self.clips[index]
        self.video_name = self._get_video_name(self.directory)

        ## build the decord loader
        try:
            decord_vr = decord.VideoReader(self.video_name, num_threads=1)
            self._set_fps(decord_vr)
        except Exception:
            # skip unreadable videos and fall back to the next clip
            return self.__getitem__((index + 1) % len(self.clips))

        ## sample the video
        self._num_frames = len(decord_vr)
        self.frames = self._sample_frames()
        if self.frames is None:
            print("no movie of length %d for video idx=%d" % (self._total_length, self.index))
            return self.__getitem__((index + 1) % len(self.clips))

        ## decode to PIL.Image
        image_list = self._decode_frame_images(decord_vr, self.frames)

        ## postproc to torch.Tensor and mask generation
        if self.transform is None:
            image_tensor = torch.stack([transforms.ToTensor()(img) for img in image_list], 0)
        else:
            image_tensor = self.transform((image_list, None))
        image_tensor = image_tensor.view(self.context_length + self.target_length, 3, *image_tensor.shape[-2:])

        ## VMAE expects [B,C,T,H,W] rather than [B,T,C,H,W]
        if self._channels_first:
            image_tensor = image_tensor.transpose(0, 1)

        if self._generate_masks and self.mask_generator is not None:
            mask = self.mask_generator()
            return image_tensor, mask.bool()
        else:
            return image_tensor
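# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the dataset API above): constructing a
# context/target dataset and reading one item. The paths are hypothetical
# placeholders; if a mask generator were supplied, it is assumed to be a
# callable returning a torch tensor, as expected by __getitem__ above.
# ---------------------------------------------------------------------------
def _example_build_context_target_dataset(setting_path='/path/to/val_setting.txt',
                                          data_root='/path/to/videos'):
    """Sketch of the frame-sampling arithmetic, assuming a ~30 fps video:

    step_units='ms', new_step=150        -> step of about 4 frames
    context_target_gap=[400, 600] (ms)   -> gap of roughly 12-18 frames;
                                            in test mode the midpoint (~15) is used

    total_length = context_length * step + gap + (target_length - 1) * step,
    so with context_length=2 and target_length=1 the video needs roughly
    2 * 4 + 15 + 0 = 23 frames after `start_frame`.
    """
    dataset = ContextAndTargetVideoDataset(
        root=data_root,
        setting=setting_path,
        train=False,
        step_units='ms',
        new_step=150,
        context_length=2,
        target_length=1,
        context_target_gap=[400, 600],
        generate_masks=False,   # no mask generator supplied in this sketch
        transform=None)         # falls back to ToTensor + stack in __getitem__
    clip = dataset[0]           # (C, T, H, W) with T = context_length + target_length = 3
    return clip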