import os
import decord
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class VideoMAE(Dataset):
    """Load your own video classification dataset.

    Parameters
    ----------
    root : str, required.
        Path to the root folder storing the dataset. Clip paths read from the
        setting file are used as given; ``root`` is not prepended to them.
    setting : str, required.
        A text file describing the dataset, one video sample per line.
        Each line holds two space-separated items: (1) the video path and (2) the integer video label.
    train : bool, default True.
        Whether to load the training or validation set.
    test_mode : bool, default False.
        Whether to perform evaluation on the test set.
        Usually a three-crop or ten-crop evaluation strategy is involved.
    name_pattern : str, default 'img_%05d.jpg'.
        The naming pattern of the decoded video frames, for example img_00012.jpg.
    video_ext : str, default 'mp4'.
        If video_loader is set to True, specify the video format accordingly.
    is_color : bool, default True.
        Whether the loaded image is color or grayscale.
    modality : str, default 'rgb'.
        Input modality. Only RGB video frames are supported for now;
        RGB-difference and optical-flow images may be added later.
    num_segments : int, default 1.
        Number of segments to evenly divide the video into clips.
        A useful technique to obtain global video-level information
        (Limin Wang et al., Temporal Segment Networks: Towards Good Practices for Deep Action Recognition, ECCV 2016).
    num_crop : int, default 1.
        Number of crops for each image.
        Common choices are three crops and ten crops during evaluation.
    new_length : int, default 1.
        The length of the input video clip. The default is a single frame, but it can be multiple frames.
        For example, new_length=16 means we extract a clip of 16 frames, spaced new_step frames apart.
    new_step : int, default 1.
        Temporal sampling rate. For example, new_step=1 extracts consecutive frames,
        and new_step=2 extracts every other frame.
    randomize_interframes : bool, default False.
        If True, the step is drawn uniformly from 1..new_step for every sample,
        so new_step acts as the maximum step.
    transform : function, default None.
        A function that takes a ``(frames, None)`` tuple and returns ``(data, mask)``,
        where ``data`` is a tensor of shape (T*C, H, W).
    temporal_jitter : bool, default False.
        Whether to temporally jitter if new_step > 1.
    video_loader : bool, default False.
        Whether to load data from video files. ``__getitem__`` currently requires video_loader=True.
    use_decord : bool, default False.
        Whether to use the Decord video loader. Frames are always decoded with Decord here,
        so this flag is stored but not otherwise read.
    lazy_init : bool, default False.
        If set to True, build a dataset instance without loading any dataset.
    is_video_dataset : bool, default True.
        If False, the clip list is not built at construction time.
    """

    def __init__(self,
                 root,
                 setting,
                 train=True,
                 test_mode=False,
                 name_pattern='img_%05d.jpg',
                 video_ext='mp4',
                 is_color=True,
                 modality='rgb',
                 num_segments=1,
                 num_crop=1,
                 new_length=1,
                 new_step=1,
                 randomize_interframes=False,
                 transform=None,
                 temporal_jitter=False,
                 video_loader=False,
                 use_decord=False,
                 lazy_init=False,
                 is_video_dataset=True):

        super(VideoMAE, self).__init__()
        self.root = root
        self.setting = setting
        self.train = train
        self.test_mode = test_mode
        self.is_color = is_color
        self.modality = modality
        self.num_segments = num_segments
        self.num_crop = num_crop
        self.new_length = new_length

        self.randomize_interframes = randomize_interframes
        self._new_step = new_step  # If randomize_interframes is True, then this is the max, otherwise it's just the skip
        # self._skip_length = self.new_length * self.new_step # If randomize_interframes is True, then this isn't used, otherwise it's used as calculated
        self.temporal_jitter = temporal_jitter
        self.name_pattern = name_pattern
        self.video_loader = video_loader
        self.video_ext = video_ext
        self.use_decord = use_decord
        self.transform = transform
        self.lazy_init = lazy_init

        if (not self.lazy_init) and is_video_dataset:
            self.clips = self._make_dataset(root, setting)
            if len(self.clips) == 0:
                raise RuntimeError("Found 0 video clips in subfolders of: " + root + "\n"
                                   "Check your data directory (opt.data-dir).")

    def __getitem__(self, index):

        directory, target = self.clips[index]

        if not self.video_loader:
            # Only decoding from video files is implemented; without it,
            # decord_vr and duration below would be undefined.
            raise NotImplementedError("VideoMAE.__getitem__ requires video_loader=True.")

        if '.' in directory.split('/')[-1]:
            # data in the "setting" file already has an extension, e.g., demo.mp4
            video_name = directory
        else:
            # data in the "setting" file does not have an extension, e.g., demo,
            # so we append one (e.g., .mp4) to complete the file name.
            video_name = '{}.{}'.format(directory, self.video_ext)

        try:
            decord_vr = decord.VideoReader(video_name, num_threads=1)
        except Exception:
            # Unreadable video: fall back to the next sample, wrapping around at the end.
            return self.__getitem__((index + 1) % len(self.clips))
        duration = len(decord_vr)

        segment_indices, skip_offsets, new_step, skip_length = self._sample_train_indices(duration)

        images = self._video_TSN_decord_batch_loader(directory, decord_vr, duration, segment_indices, skip_offsets,
                                                     new_step, skip_length)

        process_data, mask = self.transform((images, None))  # T*C,H,W
        # T*C,H,W -> T,C,H,W -> C,T,H,W
        process_data = process_data.view((self.new_length, 3) + process_data.size()[-2:]).transpose(0, 1)

        return (process_data, mask)

    def __len__(self):
        return len(self.clips)

    def _make_dataset(self, directory, setting):
        if not os.path.exists(setting):
            raise RuntimeError("Setting file %s doesn't exist. Check opt.train-list and opt.val-list. " % (setting))
        clips = []
        with open(setting) as split_f:
            for line in split_f:
                line = line.strip()
                if not line:
                    continue  # skip blank lines, e.g. a trailing empty line at the end of the file
                # line format: video_path video_label (the last space-separated item is the label)
                line_info = line.split(' ')
                if len(line_info) < 2:
                    raise RuntimeError('Video input format is not correct, missing one or more elements. %s' % line)
                elif len(line_info) > 2:
                    line_info = (' '.join(line_info[:-1]), line_info[-1])  # filename has spaces
                # The path is used as given; `directory` (the dataset root) is not prepended.
                clip_path = line_info[0]
                target = int(line_info[1])
                clips.append((clip_path, target))
        return clips

    def _sample_train_indices(self, num_frames):
        if self.randomize_interframes is False:
            new_step = self._new_step
        else:
            new_step = np.random.randint(1, self._new_step + 1)

        skip_length = self.new_length * new_step

        average_duration = (num_frames - skip_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)),
                                  average_duration)
            offsets = offsets + np.random.randint(average_duration,
                                                  size=self.num_segments)
        elif num_frames > max(self.num_segments, skip_length):
            offsets = np.sort(np.random.randint(
                num_frames - skip_length + 1,
                size=self.num_segments))
        else:
            offsets = np.zeros((self.num_segments,))

        if self.temporal_jitter:
            skip_offsets = np.random.randint(
                new_step, size=skip_length // new_step)
        else:
            skip_offsets = np.zeros(
                skip_length // new_step, dtype=int)
        return offsets + 1, skip_offsets, new_step, skip_length

    def _video_TSN_decord_batch_loader(self, directory, video_reader, duration, indices, skip_offsets, new_step,
                                       skip_length):
        sampled_list = []
        frame_id_list = []
        for seg_ind in indices:
            offset = int(seg_ind)
            for i, _ in enumerate(range(0, skip_length, new_step)):
                if offset + skip_offsets[i] <= duration:
                    frame_id = offset + skip_offsets[i] - 1
                else:
                    frame_id = offset - 1
                frame_id_list.append(frame_id)
                if offset + new_step < duration:
                    offset += new_step
        try:
            video_data = video_reader.get_batch(frame_id_list).asnumpy()
            sampled_list = [Image.fromarray(video_data[vid, :, :, :]).convert('RGB') for vid, _ in
                            enumerate(frame_id_list)]
        except Exception as e:
            raise RuntimeError(
                'Error occurred in reading frames {} from video {} of duration {}.'.format(frame_id_list, directory,
                                                                                           duration)) from e
        return sampled_list
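

# Illustrative sketch (not used by the dataset classes in this module): a
# standalone restatement of the frame-index arithmetic in
# VideoMAE._sample_train_indices and VideoMAE._video_TSN_decord_batch_loader,
# assuming randomize_interframes=False and temporal_jitter=False. The name
# `example_tsn_frame_ids` is hypothetical, and the stdlib `random` module stands
# in for np.random. For the same random draws it yields the 0-based frame ids
# the loader would request from decord's VideoReader.get_batch.
import random


def example_tsn_frame_ids(num_frames, num_segments, new_length, new_step, rng=None):
    rng = rng if rng is not None else random.Random()
    skip_length = new_length * new_step
    # Split the range of valid clip starts into num_segments equal chunks.
    average_duration = (num_frames - skip_length + 1) // num_segments
    if average_duration > 0:
        # TSN-style sampling: one random start inside each chunk.
        offsets = [s * average_duration + rng.randrange(average_duration)
                   for s in range(num_segments)]
    elif num_frames > max(num_segments, skip_length):
        # Too few frames for disjoint chunks: sorted random starts instead.
        offsets = sorted(rng.randrange(num_frames - skip_length + 1)
                         for _ in range(num_segments))
    else:
        # Very short video: every segment starts at the first frame.
        offsets = [0] * num_segments

    frame_ids = []
    for seg_offset in offsets:
        offset = seg_offset + 1  # the original code keeps 1-based offsets
        for _ in range(new_length):
            frame_ids.append(offset - 1)
            # Advance only while the next position stays before the last frame;
            # otherwise the current frame is repeated.
            if offset + new_step < num_frames:
                offset += new_step
    return frame_ids


# Example: a 10-frame video sampled as one 10-frame clip at step 1 never steps
# onto the final frame, so frame 8 is repeated:
# example_tsn_frame_ids(10, 1, 10, 1) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 8]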

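
# Illustrative sketch (not used by the classes in this module): how the
# ContextAndTargetVideoDataset class below turns millisecond step sizes into
# frame counts. `example_snap_fps` mirrors its _set_fps method and
# `example_ms_to_frames` mirrors _get_frames_per_t with step_units='ms';
# both names are hypothetical.
import math


def example_snap_fps(fps, standard_fps=(12, 24, 30, 48, 60, 100), min_fps=0.1, default_fps=30):
    snapped = None
    for st in standard_fps:
        # A standard rate is accepted if it equals floor(fps) or ceil(fps).
        if int(math.floor(fps)) == st or int(math.ceil(fps)) == st:
            snapped = st
    if snapped is None:
        snapped = int(round(fps))  # otherwise use the nearest integer rate
    if snapped < min_fps:
        snapped = default_fps  # guard against a zero or near-zero reported fps
    return snapped


def example_ms_to_frames(t_ms, fps):
    ms_per_frame = (1.0 / fps) * 1000.0
    # Round to the nearest whole frame, but never return fewer than one frame.
    return max(int(round(t_ms / ms_per_frame)), 1)


# Example: a 29.97 fps video snaps to 30 fps, so the default
# context_target_gap=[400, 600] ms becomes a gap of 12 to 18 frames:
# example_snap_fps(29.97) == 30
# [example_ms_to_frames(t, 30) for t in (400, 600)] == [12, 18]
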

class ContextAndTargetVideoDataset(VideoMAE):
    """
    A video dataset whose samples consist of (1) a "context" sequence of Tc frames
    and (2) a "target" sequence of Tt frames.

    Both sequences use the same sampling step (specifiable in frames or in real
    time units such as milliseconds) and are separated by a gap, which may vary
    between examples during training.

    The main use case is training models to predict ahead by a variable amount,
    given the context.
    """

    standard_fps = [12, 24, 30, 48, 60, 100]

    def __init__(self,
                 root,
                 setting,
                 train=True,
                 test_mode=False,
                 transform=None,
                 step_units='ms',
                 new_step=150,
                 start_frame=0,
                 context_length=2,
                 target_length=1,
                 channels_first=True,
                 generate_masks=True,
                 mask_generator=None,
                 context_target_gap=[400, 600],
                 normalize_timestamps=True,
                 default_fps=30,
                 min_fps=0.1,
                 seed=0,
                 *args,
                 **kwargs):
        super(ContextAndTargetVideoDataset, self).__init__(
            root=root,
            setting=setting,
            train=train,
            test_mode=test_mode,
            transform=transform,
            new_length=context_length,
            use_decord=True,
            lazy_init=False,
            video_loader=True,
            *args, **kwargs)

        self.context_length = self.new_length
        self.target_length = target_length

        ## convert from fps and step size to frames
        self._fps = None
        self._min_fps = min_fps
        self._default_fps = default_fps
        self._step_units = step_units
        self.new_step = new_step

        ## sampling for train and test
        self._start_frame = start_frame
        self.gap = context_target_gap
        self.seed = seed
        self.rng = np.random.RandomState(seed=seed)

        ## output formatting
        self._channels_first = channels_first
        self._normalize_timestamps = normalize_timestamps
        self._generate_masks = generate_masks
        self.mask_generator = mask_generator


    def _get_frames_per_t(self, t):
        """Convert a duration `t`, expressed in self._step_units, to a number of frames."""
        if self._step_units == 'frames' or (self._step_units is None):
            return int(t)

        assert self._fps is not None
        t_per_frame = 1 / self._fps
        if self._step_units in ['ms', 'milliseconds']:
            t_per_frame *= 1000.0

        return max(int(np.round(t / t_per_frame)), 1)

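    @property
    def new_step(self):
        # Time-based units need the fps of the current video (set by _set_fps);
        # frame units do not, so they are converted immediately.
        if self._fps is None and self._step_units not in ('frames', None):
            return None
        else:
            return self._get_frames_per_t(self._new_step)

    @new_step.setter
    def new_step(self, v):
        self._new_step = v

    @property
    def gap(self):
        if self._fps is None and self._step_units not in ('frames', None):
            return [1, 2]
        else:
            gap = [self._get_frames_per_t(self._gap[0]),
                   self._get_frames_per_t(self._gap[1])]
            # randint's upper bound is exclusive, so keep it strictly above the lower bound.
            gap[1] = max(gap[1], gap[0] + 1)
            return gap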

    @gap.setter
    def gap(self, v):
        if v is None:
            v = self._new_step
        if not isinstance(v, (list, tuple)):
            v = [v, v]
        self._gap = v

    def _get_video_name(self, directory):
        if '.' + self.video_ext in directory.split('/')[-1]:
            # data in the "setting" file has an extension, e.g. demo.mp4
            video_name = directory
        else:
            # data doesn't have an extension
            video_name = '{}.{}'.format(directory, self.video_ext)
        return video_name

    def _set_fps(self, reader):
        """Snap the reader's average fps to a standard frame rate (only needed for time-based step units)."""
        self._fps = None
        if self._step_units == 'frames' or self._step_units is None:
            return
        fps = reader.get_avg_fps()
        # Accept a standard rate if it equals floor(fps) or ceil(fps), e.g. 29.97 -> 30.
        for st in self.standard_fps:
            if (int(np.floor(fps)) == st) or (int(np.ceil(fps)) == st):
                self._fps = st
        # Otherwise fall back to the nearest integer rate.
        if self._fps is None:
            self._fps = int(np.round(fps))

        if self._fps < self._min_fps:
            self._fps = self._default_fps

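    def _get_step_and_gap(self):
        """Return (step, gap) in frames.

        When training, the gap is drawn at random (and the step too, if randomize_interframes);
        when evaluating, the step is fixed and the gap is the midpoint of the gap range.
        """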
        step = self.new_step
        if self.randomize_interframes and self.train:
            step = self.rng.randint(1, step + 1)

        if self.train:
            gap = self.rng.randint(*self.gap)
        else:
            gap = sum(self.gap) // 2
        return (step, gap)

    def _sample_frames(self):
        step, gap = self._get_step_and_gap()

        ## compute total length of sample
        ## e.g. if context_length = 2, step = 1, gap = 10, target_length = 2:
        ## total_length = 2 * 1 + 10 + (2 - 1) * 1 = 13
        ## so len(video) must be >= 13
        self._total_length = self.context_length * step + gap + (self.target_length - 1) * step
        if self._total_length > (self._num_frames - self._start_frame):
            if self.train:
                return None
            else:
                raise ValueError(
                    "movie of length %d starting at fr=%d is too long for video of %d frames" % \
                    (self._total_length, self._start_frame, self._num_frames))

        ## sample the frames randomly (if training) or from the start frame (if test)
        if self.train:
            self.start_frame_now = self.rng.randint(
                min(self._start_frame, self._num_frames - self._total_length),
                self._num_frames - self._total_length + 1)
        else:
            self.start_frame_now = min(self._start_frame, self._num_frames - self._total_length)

        frames = [self.start_frame_now + i * step for i in range(self.context_length)]
        frames += [frames[-1] + gap + i * step for i in range(self.target_length)]

        return frames

    def _decode_frame_images(self, reader, frames):
        try:
            video_data = reader.get_batch(frames).asnumpy()
            video_data = [Image.fromarray(video_data[t, :, :, :]).convert('RGB')
                          for t, _ in enumerate(frames)]
        except Exception as e:
            raise RuntimeError(
                "Error occurred in reading frames {} from video {} (index {}) of duration {}".format(
                    frames, self.video_name, self.index, self._num_frames)) from e
        return video_data

    def __getitem__(self, index):

        self.index = index
        self.directory, target = self.clips[index]

        self.video_name = self._get_video_name(self.directory)

        ## build decord loader; on failure, fall back to the next video (wrapping around at the end)
        try:
            decord_vr = decord.VideoReader(self.video_name, num_threads=1)
            self._set_fps(decord_vr)
        except Exception:
            return self.__getitem__((index + 1) % len(self.clips))

        ## sample the video
        self._num_frames = len(decord_vr)
        self.frames = self._sample_frames()
        if self.frames is None:
            # Only reached when training: the video is too short for the sampled step and gap.
            print("video idx=%d is too short for a sample of %d frames; skipping" % (self.index, self._total_length))
            return self.__getitem__((index + 1) % len(self.clips))

        ## decode to PIL.Image
        image_list = self._decode_frame_images(decord_vr, self.frames)

        ## postproc to torch.Tensor and mask generation
        if self.transform is None:
            image_tensor = torch.stack([transforms.ToTensor()(img) for img in image_list], 0)
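        else:
            # This branch expects the transform to return a single (T*C, H, W) tensor,
            # which is reshaped to (T, C, H, W).
            image_tensor = self.transform((image_list, None))
            image_tensor = image_tensor.view(self.context_length + self.target_length, 3, *image_tensor.shape[-2:])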

        ## VideoMAE expects [B,C,T,H,W] rather than [B,T,C,H,W]: transpose each sample from (T,C,H,W) to (C,T,H,W)
        if self._channels_first:
            image_tensor = image_tensor.transpose(0, 1)

        if self._generate_masks and self.mask_generator is not None:
            mask = self.mask_generator()
            return image_tensor, mask.bool()
        else:
            return image_tensor
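

# Illustrative sketch (not used by the classes in this module): the frame layout
# built by ContextAndTargetVideoDataset._sample_frames, restated as a pure
# function. The name `example_context_target_frames` is hypothetical.
def example_context_target_frames(start, context_length, target_length, step, gap):
    # Context frames are `step` apart, beginning at `start`.
    context = [start + i * step for i in range(context_length)]
    # The first target frame lies `gap` frames after the last context frame.
    target = [context[-1] + gap + i * step for i in range(target_length)]
    # Length reserved by _sample_frames; for step > 1 it exceeds the span
    # actually covered (target[-1] - start + 1) by step - 1 frames.
    total_length = context_length * step + gap + (target_length - 1) * step
    return context, target, total_length


# Example matching the comment in _sample_frames (context_length=2, step=1,
# gap=10, target_length=2):
# example_context_target_frames(0, 2, 2, 1, 10) == ([0, 1], [11, 12], 13)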