Spaces:
Starting
on
L40S
Starting
on
L40S
import os | |
from abc import ABCMeta | |
from typing import Optional, Union | |
import numpy as np | |
from .base_dataset import BaseDataset | |
from .builder import DATASETS | |
# NOTE(review): DATASETS is imported in this module but MeshDataset is not
# decorated with @DATASETS.register_module(); confirm whether registration
# was intended.
class MeshDataset(BaseDataset, metaclass=ABCMeta):
    """Mesh Dataset. This dataset only contains SMPL data.

    Each sample produced by :meth:`load_annotations` is a dict whose keys are
    the SMPL parameter names found in the annotation file, prefixed with
    ``smpl_`` (e.g. ``smpl_global_orient``, ``smpl_transl``).

    Args:
        data_prefix (str): the prefix of data path.
        pipeline (list): a list of dict, where each element represents
            an operation defined in `detrsmpl.datasets.pipelines`.
        dataset_name (str): the name of the dataset. It is used to
            identify the type of evaluation metric.
        ann_file (str | None, optional): the annotation file. When ann_file
            is a str, it is resolved relative to
            ``<data_prefix>/preprocessed_datasets`` and read by
            :meth:`load_annotations`. When ann_file is None, the subclass is
            expected to read according to data_prefix. Default: None.
        test_mode (bool, optional): in train mode or test mode.
            Default: False.
    """

    def __init__(self,
                 data_prefix: str,
                 pipeline: list,
                 dataset_name: str,
                 ann_file: Optional[str] = None,
                 test_mode: bool = False):
        # Assigned before super().__init__, presumably because
        # BaseDataset.__init__ may already use it (e.g. while loading
        # annotations) -- TODO confirm against BaseDataset.
        self.dataset_name = dataset_name
        super(MeshDataset, self).__init__(data_prefix=data_prefix,
                                          pipeline=pipeline,
                                          ann_file=ann_file,
                                          test_mode=test_mode)

    def get_annotation_file(self):
        """Resolve ``self.ann_file`` against the preprocessed-data folder.

        Rewrites ``self.ann_file`` in place to
        ``<data_prefix>/preprocessed_datasets/<ann_file>``. If ``ann_file``
        is already an absolute path, ``os.path.join`` returns it unchanged.
        Calling this twice on a relative path prefixes it twice.

        Raises:
            TypeError: if ``self.ann_file`` is None.
        """
        if self.ann_file is None:
            raise TypeError(
                f'{type(self).__name__}.ann_file is None; pass ann_file or '
                'override load_annotations() to read from data_prefix.')
        ann_prefix = os.path.join(self.data_prefix, 'preprocessed_datasets')
        self.ann_file = os.path.join(ann_prefix, self.ann_file)

    def load_annotations(self):
        """Load SMPL annotations and build one info dict per sample.

        The annotation file is expected to be an ``.npz`` archive whose
        ``'smpl'`` entry is a pickled dict mapping SMPL parameter names to
        arrays indexed by sample along their first axis. If ``'transl'`` is
        missing, a zero ``(num_data, 3)`` array is filled in.

        Sets ``self.smpl``, ``self.has_smpl`` (all ones, one per sample),
        ``self.data_infos`` and ``self.num_data``.
        """
        self.get_annotation_file()
        # allow_pickle=True is needed because 'smpl' holds a Python object
        # (unwrapped with .item()). Unpickling can execute arbitrary code, so
        # only load annotation files from trusted sources. The context
        # manager closes the archive's file handle once 'smpl' is extracted.
        with np.load(self.ann_file, allow_pickle=True) as data:
            self.smpl = data['smpl'].item()
        num_data = self.smpl['global_orient'].shape[0]
        if 'transl' not in self.smpl:
            # Guarantee every sample carries a translation entry.
            self.smpl['transl'] = np.zeros((num_data, 3))
        self.has_smpl = np.ones(num_data)
        self.data_infos = [
            {'smpl_' + key: value[idx] for key, value in self.smpl.items()}
            for idx in range(num_data)
        ]
        self.num_data = len(self.data_infos)