import os

import decord
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class VideoMAE(Dataset):
    """Load your own video classification dataset.

    Parameters
    ----------
    root : str, required.
        Path to the root folder storing the dataset.
    setting : str, required.
        A text file describing the dataset, one line per video sample.
        Each line contains two items: (1) the video path and (2) the video label.
    train : bool, default True.
        Whether to load the training or the validation set.
    test_mode : bool, default False.
        Whether to perform evaluation on the test set.
        Usually a three-crop or ten-crop evaluation strategy is involved.
    name_pattern : str, default 'img_%05d.jpg'.
        The naming pattern of the decoded video frames, e.g. img_00012.jpg.
    video_ext : str, default 'mp4'.
        If video_loader is set to True, specify the video format accordingly.
    is_color : bool, default True.
        Whether the loaded images are color or grayscale.
    modality : str, default 'rgb'.
        Input modality. Only RGB video frames are supported for now;
        support for RGB-difference and optical-flow images may be added later.
    num_segments : int, default 1.
        Number of segments into which the video is evenly divided.
        A useful technique to obtain global video-level information.
        Limin Wang et al., "Temporal Segment Networks: Towards Good Practices
        for Deep Action Recognition", ECCV 2016.
    num_crop : int, default 1.
        Number of crops for each image. Common choices are three crops and
        ten crops during evaluation.
    new_length : int, default 1.
        The length of the input video clip. The default is a single frame,
        but it can be multiple frames; for example, new_length=16 extracts a
        clip of 16 consecutive frames.
    new_step : int, default 1.
        Temporal sampling rate. new_step=1 extracts consecutive frames;
        new_step=2 extracts every other frame.
    randomize_interframes : bool, default False.
        If True, new_step is treated as a maximum and the actual step is
        sampled uniformly from [1, new_step] for each clip.
    temporal_jitter : bool, default False.
        Whether to apply temporal jittering when new_step > 1.
    video_loader : bool, default False.
        Whether to use a video loader to load the data.
    use_decord : bool, default False.
        Whether to use the Decord video loader to load the data. Otherwise
        the mmcv video loader is used.
    transform : function, default None.
        A function that takes data and label and transforms them.
    lazy_init : bool, default False.
        If set to True, build the dataset instance without loading any data.
    is_video_dataset : bool, default True.
        Whether the setting file describes a video dataset; if False (or if
        lazy_init is True), the clip list is not built.
    """
    def __init__(self,
                 root,
                 setting,
                 train=True,
                 test_mode=False,
                 name_pattern='img_%05d.jpg',
                 video_ext='mp4',
                 is_color=True,
                 modality='rgb',
                 num_segments=1,
                 num_crop=1,
                 new_length=1,
                 new_step=1,
                 randomize_interframes=False,
                 transform=None,
                 temporal_jitter=False,
                 video_loader=False,
                 use_decord=False,
                 lazy_init=False,
                 is_video_dataset=True):
        super(VideoMAE, self).__init__()

        self.root = root
        self.setting = setting
        self.train = train
        self.test_mode = test_mode
        self.is_color = is_color
        self.modality = modality
        self.num_segments = num_segments
        self.num_crop = num_crop
        self.new_length = new_length
        self.randomize_interframes = randomize_interframes
        # If randomize_interframes is True this is the maximum step;
        # otherwise it is the fixed inter-frame skip.
        self._new_step = new_step
        self.temporal_jitter = temporal_jitter
        self.name_pattern = name_pattern
        self.video_loader = video_loader
        self.video_ext = video_ext
        self.use_decord = use_decord
        self.transform = transform
        self.lazy_init = lazy_init

        if (not self.lazy_init) and is_video_dataset:
            self.clips = self._make_dataset(root, setting)
            if len(self.clips) == 0:
                raise RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
                                   "Check your data directory (opt.data-dir).")
    def __getitem__(self, index):
        directory, target = self.clips[index]
        if self.video_loader:
            if '.' in directory.split('/')[-1]:
                # entries in the "setting" file already have an extension, e.g. demo.mp4
                video_name = directory
            else:
                # entries in the "setting" file have no extension, e.g. demo,
                # so we append the extension (e.g. .mp4) to complete the file name
                video_name = '{}.{}'.format(directory, self.video_ext)
            try:
                decord_vr = decord.VideoReader(video_name, num_threads=1)
            except Exception:
                # unreadable video: fall back to the next clip (wrap around at the end)
                return self.__getitem__((index + 1) % len(self.clips))

        duration = len(decord_vr)

        segment_indices, skip_offsets, new_step, skip_length = self._sample_train_indices(duration)
        images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices,
                                                     skip_offsets, new_step, skip_length)

        process_data, mask = self.transform((images, None))  # T*C,H,W
        # T*C,H,W -> T,C,H,W -> C,T,H,W
        process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)
        return (process_data, mask)
    def __len__(self):
        return len(self.clips)
    def _make_dataset(self, directory, setting):
        if not os.path.exists(setting):
            raise RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list." % (setting))
        clips = []
        with open(setting) as split_f:
            data = split_f.readlines()
            for line in data:
                line_info = line.split(' ')
                # line format: video_path video_label
                if len(line_info) < 2:
                    raise RuntimeError('Video input format is not correct, missing one or more elements. %s' % line)
                elif len(line_info) > 2:
                    # the file name itself contains spaces
                    line_info = (' '.join(line_info[:-1]), line_info[-1])
                clip_path = os.path.join(line_info[0])
                target = int(line_info[1])
                item = (clip_path, target)
                clips.append(item)
        return clips
    def _sample_train_indices(self, num_frames):
        if not self.randomize_interframes:
            new_step = self._new_step
        else:
            new_step = np.random.randint(1, self._new_step + 1)

        skip_length = self.new_length * new_step

        average_duration = (num_frames - skip_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)), average_duration)
            offsets = offsets + np.random.randint(average_duration, size=self.num_segments)
        elif num_frames > max(self.num_segments, skip_length):
            offsets = np.sort(np.random.randint(num_frames - skip_length + 1, size=self.num_segments))
        else:
            offsets = np.zeros((self.num_segments,))

        if self.temporal_jitter:
            skip_offsets = np.random.randint(new_step, size=skip_length // new_step)
        else:
            skip_offsets = np.zeros(skip_length // new_step, dtype=int)
        return offsets + 1, skip_offsets, new_step, skip_length
    def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets,
                                       new_step, skip_length):
        sampled_list = []
        frame_id_list = []
        for seg_ind in indices:
            offset = int(seg_ind)
            for i, _ in enumerate(range(0, skip_length, new_step)):
                if offset + skip_offsets[i] <= duration:
                    frame_id = offset + skip_offsets[i] - 1
                else:
                    frame_id = offset - 1
                frame_id_list.append(frame_id)
                if offset + new_step < duration:
                    offset += new_step
        try:
            video_data = video_reader.get_batch(frame_id_list).asnumpy()
            sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB')
                            for vid, _ in enumerate(frame_id_list)]
        except Exception:
            raise RuntimeError('Error occurred in reading frames {} from video {} of duration {}.'.format(
                frame_id_list, directory, duration))
        return sampled_list
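
# A minimal usage sketch for VideoMAE (the paths and the toy transform are
# hypothetical; __getitem__ expects self.transform to return a
# (tensor, mask) pair, so substitute the real training transform):
#
#     def toy_transform(inputs):
#         images, _ = inputs
#         # stack PIL frames into a (T*C, H, W) tensor; no masking
#         data = torch.cat([transforms.ToTensor()(img) for img in images], 0)
#         return data, None
#
#     dataset = VideoMAE(root='/data/videos', setting='/data/train.txt',
#                        video_loader=True, use_decord=True,
#                        new_length=16, new_step=2, transform=toy_transform)
#     clip, mask = dataset[0]  # clip: (C, T, H, W)
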
class ContextAndTargetVideoDataset(VideoMAE):
    """
    A video dataset whose samples consist of (1) a "context" sequence of length Tc
    and (2) a "target" sequence of length Tt.

    The two sequences share the same frame rate (specifiable in real units) but are
    separated by a specified gap (which may vary across examples).

    The main use case is training models to predict ahead by some variable amount,
    given the context.
    """

    standard_fps = [12, 24, 30, 48, 60, 100]
    def __init__(self,
                 root,
                 setting,
                 train=True,
                 test_mode=False,
                 transform=None,
                 step_units='ms',
                 new_step=150,
                 start_frame=0,
                 context_length=2,
                 target_length=1,
                 channels_first=True,
                 generate_masks=True,
                 mask_generator=None,
                 context_target_gap=[400, 600],
                 normalize_timestamps=True,
                 default_fps=30,
                 min_fps=0.1,
                 seed=0,
                 *args,
                 **kwargs):
        super(ContextAndTargetVideoDataset, self).__init__(
            root=root,
            setting=setting,
            train=train,
            test_mode=test_mode,
            transform=transform,
            new_length=context_length,
            use_decord=True,
            lazy_init=False,
            video_loader=True,
            *args, **kwargs)
        self.context_length = self.new_length
        self.target_length = target_length

        ## convert from fps and step size to frames
        self._fps = None
        self._min_fps = min_fps
        self._default_fps = default_fps
        self._step_units = step_units
        self.new_step = new_step

        ## sampling for train and test
        self._start_frame = start_frame
        self.gap = context_target_gap
        self.seed = seed
        self.rng = np.random.RandomState(seed=seed)

        ## output formatting
        self._channels_first = channels_first
        self._normalize_timestamps = normalize_timestamps
        self._generate_masks = generate_masks
        self.mask_generator = mask_generator
    def _get_frames_per_t(self, t):
        """Convert a time interval t (in step_units) to a number of frames."""
        if self._step_units == 'frames' or (self._step_units is None):
            return int(t)

        assert self._fps is not None
        t_per_frame = 1 / self._fps
        if self._step_units in ['ms', 'milliseconds']:
            t_per_frame *= 1000.0
        return max(int(np.round(t / t_per_frame)), 1)
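
    # Worked example (illustrative numbers): with step_units='ms' and
    # self._fps == 25, t_per_frame = (1 / 25) * 1000 = 40 ms, so t = 150 ms
    # maps to round(150 / 40) = 4 frames; the max(..., 1) guarantees the
    # step is never zero frames.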
    @property
    def new_step(self):
        if self._fps is None:
            return None
        else:
            return self._get_frames_per_t(self._new_step)

    @new_step.setter
    def new_step(self, v):
        self._new_step = v

    @property
    def gap(self):
        if self._fps is None:
            return [1, 2]
        else:
            gap = [self._get_frames_per_t(self._gap[0]),
                   self._get_frames_per_t(self._gap[1])]
            gap[1] = max(gap[1], gap[0] + 1)
            return gap

    @gap.setter
    def gap(self, v):
        if v is None:
            v = self._new_step
        if not isinstance(v, (list, tuple)):
            v = [v, v]
        self._gap = v
    def _get_video_name(self, directory):
        if ''.join(['.', self.video_ext]) in directory.split('/')[-1]:
            # the entry in the "setting" file has an extension, e.g. demo.mp4
            video_name = directory
        else:
            # the entry doesn't have an extension, so append one
            video_name = '{}.{}'.format(directory, self.video_ext)
        return video_name
    def _set_fps(self, reader):
        """Snap the video fps to a standard value."""
        if self._step_units == 'frames' or self._step_units is None:
            self._fps = None
        else:
            self._fps = None
            fps = reader.get_avg_fps()
            for st in self.standard_fps:
                if (int(np.floor(fps)) == st) or (int(np.ceil(fps)) == st):
                    self._fps = st
            if self._fps is None:
                self._fps = int(np.round(fps))
            if self._fps < self._min_fps:
                self._fps = self._default_fps
    def _get_step_and_gap(self):
        step = self.new_step
        if self.randomize_interframes and self.train:
            step = self.rng.randint(1, step + 1)

        if self.train:
            gap = self.rng.randint(*self.gap)
        else:
            gap = sum(self.gap) // 2
        return (step, gap)
    def _sample_frames(self):
        step, gap = self._get_step_and_gap()

        ## compute total length of sample
        ## e.g. if context_length = 2, step = 1, gap = 10, target_length = 2:
        ## total_length = 2 * 1 + 10 + (2 - 1) * 1 = 13
        ## so len(video) must be >= 13
        self._total_length = self.context_length * step + gap + (self.target_length - 1) * step
        if self._total_length > (self._num_frames - self._start_frame):
            if self.train:
                return None
            else:
                raise ValueError(
                    "movie of length %d starting at fr=%d is too long for video of %d frames" %
                    (self._total_length, self._start_frame, self._num_frames))

        ## sample the frames randomly (if training) or from the start frame (if testing)
        if self.train:
            self.start_frame_now = self.rng.randint(
                min(self._start_frame, self._num_frames - self._total_length),
                self._num_frames - self._total_length + 1)
        else:
            self.start_frame_now = min(self._start_frame, self._num_frames - self._total_length)

        frames = [self.start_frame_now + i * step for i in range(self.context_length)]
        frames += [frames[-1] + gap + i * step for i in range(self.target_length)]
        return frames
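
    # Continuing the example above: with context_length=2, step=1, gap=10,
    # target_length=2 and start_frame_now=s, the sampled indices are
    # [s, s+1] for the context and [s+11, s+12] for the target.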
    def _decode_frame_images(self, reader, frames):
        try:
            video_data = reader.get_batch(frames).asnumpy()
            video_data = [Image.fromarray(video_data[t, :, :, :]).convert('RGB')
                          for t, _ in enumerate(frames)]
        except Exception:
            raise RuntimeError(
                "Error occurred in reading frames {} from video {} of duration {}".format(
                    frames, self.index, self._num_frames))
        return video_data
    def __getitem__(self, index):
        self.index = index
        self.directory, target = self.clips[index]
        self.video_name = self._get_video_name(self.directory)

        ## build the decord loader
        try:
            decord_vr = decord.VideoReader(self.video_name, num_threads=1)
            self._set_fps(decord_vr)
        except Exception:
            # unreadable video: fall back to the next clip (wrap around at the end)
            return self.__getitem__((index + 1) % len(self.clips))

        ## sample the video
        self._num_frames = len(decord_vr)
        self.frames = self._sample_frames()
        if self.frames is None:
            print("no clip of length %d fits in video idx=%d" % (self._total_length, self.index))
            return self.__getitem__((index + 1) % len(self.clips))

        ## decode to PIL.Image
        image_list = self._decode_frame_images(decord_vr, self.frames)

        ## postprocess to torch.Tensor and generate the mask
        if self.transform is None:
            image_tensor = torch.stack([transforms.ToTensor()(img) for img in image_list], 0)
        else:
            image_tensor = self.transform((image_list, None))
        image_tensor = image_tensor.view(self.context_length + self.target_length, 3, *image_tensor.shape[-2:])

        ## VMAE expects [B,C,T,H,W] rather than [B,T,C,H,W]
        if self._channels_first:
            image_tensor = image_tensor.transpose(0, 1)

        if self._generate_masks and self.mask_generator is not None:
            mask = self.mask_generator()
            return image_tensor, mask.bool()
        else:
            return image_tensor
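

if __name__ == "__main__":
    # A minimal smoke-test sketch for ContextAndTargetVideoDataset. The
    # root/setting paths are placeholders, and the toy tube-mask generator
    # below only illustrates the expected mask_generator interface (a
    # zero-argument callable returning a tensor); swap in the project's
    # real mask generator for actual training.
    def toy_mask_generator(num_patches=196 * 3, mask_ratio=0.9):
        # keep (1 - mask_ratio) of the patches visible, mask the rest
        num_visible = int(num_patches * (1 - mask_ratio))
        mask = np.hstack([np.zeros(num_visible),
                          np.ones(num_patches - num_visible)])
        np.random.shuffle(mask)
        return torch.from_numpy(mask)

    dataset = ContextAndTargetVideoDataset(
        root='/data/videos',            # placeholder path
        setting='/data/train.txt',      # placeholder setting file
        train=True,
        context_length=2,
        target_length=1,
        new_step=150,                   # inter-frame step in milliseconds
        context_target_gap=[400, 600],  # context-to-target gap range in ms
        mask_generator=toy_mask_generator,
    )
    clip, mask = dataset[0]             # clip: (C, T, H, W) when channels_first
    print(clip.shape, mask.shape)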