import copy
from typing import Optional

import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from skimage.util.shape import view_as_windows

from .builder import DATASETS
from .human_image_dataset import HumanImageDataset


def get_vid_name(image_path: str):
    """Get base_dir of the given path."""
    content = image_path.split('/')
    vid_name = '/'.join(content[:-1])
    return vid_name


def split_into_chunks(data_infos: list, seq_len: int, stride: int,
                      test_mode: bool, only_vid_name: bool):
    """Split annotations into chunks.
    Adapted from https://github.com/mkocabas/VIBE
    Args:
        data_infos (list): parsed annotations.
        seq_len (int): the length of each chunk.
        stride (int): the interval between chunks.
        test_mode (bool): if test_mode is true, an additional chunk will be
            added to cover all frames. Otherwise, the last few frames will
            be dropped.
        only_vid_name (bool): if only_vid_name is true, image_path only
            contains the video name. Otherwise, image_path contains both
            the video name and the frame index.

    Returns:
        list:
            shape: [N, 4]. Each chunk contains four parameters: start_frame,
            end_frame, valid_start_frame, valid_end_frame. The last two
            parameters are used to suppress redundant frames.
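
    Example:
        An illustrative sketch of the expected output. The video and frame
        names below are made-up placeholders; only the returned indices
        matter:

        >>> paths = ['vidA/0.png', 'vidA/1.png', 'vidA/2.png',
        ...          'vidB/0.png', 'vidB/1.png']
        >>> split_into_chunks(paths, seq_len=2, stride=1,
        ...                   test_mode=False, only_vid_name=False)
        [[0, 1, 0, 1], [1, 2, 1, 2], [3, 4, 3, 4]]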
    """
    vid_names = []
    for image_path in data_infos:
        if only_vid_name:
            vid_name = image_path
        else:
            vid_name = get_vid_name(image_path)
        vid_names.append(vid_name)
    vid_names = np.array(vid_names)
    video_start_end_indices = []

    video_names, group = np.unique(vid_names, return_index=True)
    perm = np.argsort(group)
    video_names, group = video_names[perm], group[perm]

    indices = np.split(np.arange(0, vid_names.shape[0]), group[1:])

    for idx in range(len(video_names)):
        indexes = indices[idx]
        if indexes.shape[0] < seq_len:
            continue
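        # Slide a window of `seq_len` consecutive frames with step `stride`;
        # each chunk records [start, end, valid_start, valid_end].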
        chunks = view_as_windows(indexes, (seq_len, ), step=stride)
        start_finish = chunks[:, (0, -1, 0, -1)].tolist()
        video_start_end_indices += start_finish
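        # In test mode, append one extra chunk aligned to the last frame so
        # that trailing frames are covered; its valid range marks only the
        # frames not already covered by the previous chunk.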
        if chunks[-1][-1] < indexes[-1] and test_mode:
            start_frame = indexes[-1] - seq_len + 1
            end_frame = indexes[-1]
            valid_start_frame = chunks[-1][-1] + 1
            valid_end_frame = indexes[-1]
            extra_start_finish = [[
                start_frame, end_frame, valid_start_frame, valid_end_frame
            ]]
            video_start_end_indices += extra_start_finish

    return video_start_end_indices


@DATASETS.register_module()
class HumanVideoDataset(HumanImageDataset):
    """Human Video Dataset.

    Args:
        data_prefix (str): the prefix of data path.
        pipeline (list): a list of dict, where each element represents
            an operation defined in `mmhuman3d.datasets.pipelines`.
        dataset_name (str): the name of the dataset. It is used to
            identify the type of evaluation metric.
        seq_len (int, optional): the length of input sequence. Default: 16.
        overlap (float, optional): the overlap between different sequences.
            Default: 0.
        only_vid_name (bool, optional): the format of image_path.
            If only_vid_name is true, image_path only contains the video
            name. Otherwise, image_path contains both the video name and
            the frame index. Default: False.
        body_model (dict | None, optional): the config for body model,
            which will be used to generate meshes and keypoints.
            Default: None.
        ann_file (str | None, optional): the annotation file. When ann_file
            is str, the subclass is expected to read from the ann_file.
            When ann_file is None, the subclass is expected to read
            according to data_prefix. Default: None.
        convention (str, optional): keypoints convention. Keypoints will be
            converted from "human_data" to the given one.
            Default: "human_data".
        test_mode (bool, optional): whether the dataset is used for testing.
            Default: False.
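
    Example:
        A hedged configuration sketch in mmcv config style; the annotation
        file and dataset name below are placeholders, not assets that ship
        with this module:

        >>> cfg = dict(
        ...     type='HumanVideoDataset',
        ...     data_prefix='data',
        ...     pipeline=[],
        ...     dataset_name='pw3d',
        ...     seq_len=16,
        ...     overlap=0.5,
        ...     only_vid_name=False,
        ...     ann_file='sample_video.npz')
        >>> # stride between chunks is int(seq_len * (1 - overlap))
        >>> int(cfg['seq_len'] * (1 - cfg['overlap']))
        8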
    """

    def __init__(self,
                 data_prefix: str,
                 pipeline: list,
                 dataset_name: str,
                 seq_len: Optional[int] = 16,
                 overlap: Optional[float] = 0.,
                 only_vid_name: Optional[bool] = False,
                 body_model: Optional[dict] = None,
                 ann_file: Optional[str] = None,
                 convention: Optional[str] = 'human_data',
                 test_mode: Optional[bool] = False):
        super(HumanVideoDataset, self).__init__(data_prefix=data_prefix,
                                                pipeline=pipeline,
                                                dataset_name=dataset_name,
                                                body_model=body_model,
                                                convention=convention,
                                                ann_file=ann_file,
                                                test_mode=test_mode)
        self.seq_len = seq_len
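        # Chunks start every `stride` frames: overlap=0 gives back-to-back
        # chunks, while a larger overlap makes consecutive chunks share
        # frames.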
        self.stride = int(seq_len * (1 - overlap))
        self.vid_indices = split_into_chunks(self.human_data['image_path'],
                                             self.seq_len, self.stride,
                                             test_mode, only_vid_name)
        self.vid_indices = np.array(self.vid_indices)

    def __len__(self):
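        """Get the number of video chunks."""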
        return len(self.vid_indices)

    def prepare_data(self, idx: int):
        """Prepare data for each chunk.

        Step 1: get the annotation of each frame in the chunk. Step 2: add
        the metas of the chunk.
        """
        start_idx, end_idx = self.vid_indices[idx][:2]
        batch_results = []
        image_path = []
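        # Step 1: collect the annotation of every frame in the chunk and run
        # it through the data pipeline.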
        for frame_idx in range(start_idx, end_idx + 1):
            frame_results = copy.deepcopy(self.prepare_raw_data(frame_idx))
            image_path.append(frame_results.pop('image_path'))
            if 'features' in self.human_data:
                frame_results['features'] = copy.deepcopy(
                    self.human_data['features'][frame_idx])
            frame_results = self.pipeline(frame_results)
            batch_results.append(frame_results)
        video_results = {}
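        # Collate per-frame results: stack tensors along a new time axis and
        # keep the other annotation types as lists.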
        for key in batch_results[0].keys():
            batch_anno = []
            for item in batch_results:
                batch_anno.append(item[key])
            if isinstance(batch_anno[0], torch.Tensor):
                batch_anno = torch.stack(batch_anno, dim=0)
            video_results[key] = batch_anno
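        # Step 2: attach chunk-level metas (frame indices and image paths).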
        img_metas = {
            'frame_idx': self.vid_indices[idx],
            'image_path': image_path
        }
        video_results['img_metas'] = DC(img_metas, cpu_only=True)
        return video_results