# AiOS: detrsmpl/data/datasets/human_video_dataset.py
import copy
from typing import Optional
import numpy as np
import torch
from mmcv.parallel import DataContainer as DC
from skimage.util.shape import view_as_windows
from .builder import DATASETS
from .human_image_dataset import HumanImageDataset


def get_vid_name(image_path: str):
"""Get base_dir of the given path."""
content = image_path.split('/')
vid_name = '/'.join(content[:-1])
    return vid_name


def split_into_chunks(data_infos: list, seq_len: int, stride: int,
test_mode: bool, only_vid_name: bool):
"""Split annotations into chunks.
Adapted from https://github.com/mkocabas/VIBE
Args:
data_infos (list): parsed annotations.
seq_len (int): the length of each chunk.
stride (int): the interval between chunks.
test_mode (bool): if test_mode is true, then an additional chunk
            will be added to cover all frames. Otherwise, the last few
            frames will be dropped.
only_vid_name (bool): if only_vid_name is true, image_path only
contains the video name. Otherwise, image_path contains both
video_name and frame index.
Return:
list:
shape: [N, 4]. Each chunk contains four parameters: start_frame,
end_frame, valid_start_frame, valid_end_frame. The last two
parameters are used to suppress redundant frames.
"""
vid_names = []
for image_path in data_infos:
if only_vid_name:
vid_name = image_path
else:
vid_name = get_vid_name(image_path)
vid_names.append(vid_name)
vid_names = np.array(vid_names)
video_start_end_indices = []
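    # Group frame indices by video. np.unique returns the first occurrence
    # of each video name; re-sorting by that index restores the original
    # order, assuming frames of the same video are stored contiguously.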
video_names, group = np.unique(vid_names, return_index=True)
perm = np.argsort(group)
video_names, group = video_names[perm], group[perm]
indices = np.split(np.arange(0, vid_names.shape[0]), group[1:])
for idx in range(len(video_names)):
indexes = indices[idx]
if indexes.shape[0] < seq_len:
continue
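        # Slide a window of length seq_len over this video's frame indices,
        # moving `stride` frames at a time; each window becomes one chunk.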
chunks = view_as_windows(indexes, (seq_len, ), step=stride)
start_finish = chunks[:, (0, -1, 0, -1)].tolist()
video_start_end_indices += start_finish
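        # In test mode, append one extra chunk ending at the last frame so
        # the tail of the video is covered; its valid_* range marks only the
        # frames not already covered by the windows above.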
if chunks[-1][-1] < indexes[-1] and test_mode:
start_frame = indexes[-1] - seq_len + 1
end_frame = indexes[-1]
valid_start_frame = chunks[-1][-1] + 1
valid_end_frame = indexes[-1]
extra_start_finish = [[
start_frame, end_frame, valid_start_frame, valid_end_frame
]]
video_start_end_indices += extra_start_finish
    return video_start_end_indices


@DATASETS.register_module()
class HumanVideoDataset(HumanImageDataset):
"""Human Video Dataset.
Args:
data_prefix (str): the prefix of data path.
        pipeline (list): a list of dicts, where each element represents
            an operation defined in `mmhuman3d.datasets.pipelines`.
        dataset_name (str): the name of the dataset. It is used to
            identify the type of evaluation metric.
seq_len (int, optional): the length of input sequence. Default: 16.
overlap (float, optional): the overlap between different sequences.
Default: 0
only_vid_name (bool, optional): the format of image_path.
If only_vid_name is true, image_path only contains the video
name. Otherwise, image_path contains both video_name and frame
index.
body_model (dict | None, optional): the config for body model,
which will be used to generate meshes and keypoints.
Default: None.
        ann_file (str | None, optional): the annotation file. When ann_file
            is a string, the subclass is expected to read from it. When
            ann_file is None, the subclass is expected to read according
            to data_prefix. Default: None.
convention (str, optional): keypoints convention. Keypoints will be
converted from "human_data" to the given one.
Default: "human_data"
test_mode (bool, optional): in train mode or test mode. Default: False.
"""
def __init__(self,
data_prefix: str,
pipeline: list,
dataset_name: str,
seq_len: Optional[int] = 16,
overlap: Optional[float] = 0.,
only_vid_name: Optional[bool] = False,
                 body_model: Optional[dict] = None,
                 ann_file: Optional[str] = None,
convention: Optional[str] = 'human_data',
test_mode: Optional[bool] = False):
super(HumanVideoDataset, self).__init__(data_prefix=data_prefix,
pipeline=pipeline,
dataset_name=dataset_name,
body_model=body_model,
convention=convention,
ann_file=ann_file,
test_mode=test_mode)
self.seq_len = seq_len
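        # Consecutive chunks share roughly overlap * seq_len frames, e.g.
        # seq_len=16 with overlap=0.5 gives stride = int(16 * 0.5) = 8.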
self.stride = int(seq_len * (1 - overlap))
self.vid_indices = split_into_chunks(self.human_data['image_path'],
self.seq_len, self.stride,
test_mode, only_vid_name)
        self.vid_indices = np.array(self.vid_indices)

    def __len__(self):
        return len(self.vid_indices)

    def prepare_data(self, idx: int):
        """Prepare the data for one chunk.

        Step 1: get the annotation of each frame in the chunk. Step 2: add
        the metas of the whole chunk.
        """
start_idx, end_idx = self.vid_indices[idx][:2]
batch_results = []
image_path = []
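        # Step 1: run the image pipeline on every frame of the chunk and
        # keep the per-frame results (plus pre-extracted features, if any).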
for frame_idx in range(start_idx, end_idx + 1):
frame_results = copy.deepcopy(self.prepare_raw_data(frame_idx))
image_path.append(frame_results.pop('image_path'))
if 'features' in self.human_data:
frame_results['features'] = \
copy.deepcopy(self.human_data['features'][frame_idx])
frame_results = self.pipeline(frame_results)
batch_results.append(frame_results)
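        # Step 2: collate the per-frame results; tensors are stacked along a
        # new leading (time) dimension, everything else is kept as a list.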
video_results = {}
for key in batch_results[0].keys():
batch_anno = []
for item in batch_results:
batch_anno.append(item[key])
if isinstance(batch_anno[0], torch.Tensor):
batch_anno = torch.stack(batch_anno, dim=0)
video_results[key] = batch_anno
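        # Wrap the chunk-level metadata in a DataContainer (cpu_only=True)
        # so the collate function keeps it on CPU and does not stack it.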
img_metas = {
'frame_idx': self.vid_indices[idx],
'image_path': image_path
}
video_results['img_metas'] = DC(img_metas, cpu_only=True)
return video_results