import logging
import pickle
from enum import Enum
from typing import Any, TypeVar, Union

import numpy as np
from mmcv.utils import print_log

from detrsmpl.data.data_structures.human_data import HumanData
from detrsmpl.utils.path_utils import (
    Existence,
    check_path_existence,
    check_path_suffix,
)

# In T = TypeVar('T'), T can be anything.
# See the definition of typing.TypeVar for details.
_HumanData = TypeVar('_HumanData')

_MultiHumanData_SUPPORTED_KEYS = HumanData.SUPPORTED_KEYS.copy()
_MultiHumanData_SUPPORTED_KEYS.update(
    {'optional': {
        'type': dict,
        'slice_key': 'frame_range',
        'dim': 0
    }})


class _KeyCheck(Enum):
    PASS = 0
    WARN = 1
    ERROR = 2


class MultiHumanData(HumanData):
    SUPPORTED_KEYS = _MultiHumanData_SUPPORTED_KEYS

    def __new__(cls: _HumanData, *args: Any, **kwargs: Any) -> _HumanData:
        """Create a new instance of MultiHumanData.

        Args:
            cls (MultiHumanData): MultiHumanData class.

        Returns:
            MultiHumanData: An instance of MultiHumanData.
        """
        ret_human_data = super().__new__(cls, *args, **kwargs)
        setattr(ret_human_data, '__data_len__', -1)
        setattr(ret_human_data, '__instance_num__', -1)
        setattr(ret_human_data, '__key_strict__', False)
        setattr(ret_human_data, '__keypoints_compressed__', False)
        return ret_human_data

    def load(self, npz_path: str):
        """Load data from npz_path and update it to self.

        Args:
            npz_path (str): Path to a dumped npz file.
        """
        supported_keys = self.__class__.SUPPORTED_KEYS
        with np.load(npz_path, allow_pickle=True) as npz_file:
            tmp_data_dict = dict(npz_file)
            for key, value in list(tmp_data_dict.items()):
                if isinstance(value, np.ndarray) and \
                        len(value.shape) == 0:
                    # value was not an ndarray before dump
                    value = value.item()
                elif key in supported_keys and \
                        type(value) != supported_keys[key]['type']:
                    value = supported_keys[key]['type'](value)
                if value is None:
                    tmp_data_dict.pop(key)
                elif key in ('__key_strict__', '__data_len__',
                             '__instance_num__',
                             '__keypoints_compressed__'):
                    self.__setattr__(key, value)
                    # pop the attribute to keep the dict clean
                    tmp_data_dict.pop(key)
                elif key == 'bbox_xywh' and value.shape[1] == 4:
                    # append a confidence column of ones
                    value = np.hstack([value, np.ones([value.shape[0], 1])])
                    tmp_data_dict[key] = value
                else:
                    tmp_data_dict[key] = value
            self.update(tmp_data_dict)
            self.__set_default_values__()

    def dump(self, npz_path: str, overwrite: bool = True):
        """Dump keys and items to an npz file.

        Args:
            npz_path (str): Path to a dumped npz file.
            overwrite (bool, optional): Whether to overwrite if there is
                already a file. Defaults to True.

        Raises:
            ValueError: npz_path does not end with '.npz'.
            FileExistsError: When overwrite is False and the file exists.
        """
        if not check_path_suffix(npz_path, ['.npz']):
            raise ValueError('Not an npz file.')
        if not overwrite:
            if check_path_existence(npz_path, 'file') == Existence.FileExist:
                raise FileExistsError
        dict_to_dump = {
            '__key_strict__': self.__key_strict__,
            '__data_len__': self.__data_len__,
            '__instance_num__': self.__instance_num__,
            '__keypoints_compressed__': self.__keypoints_compressed__,
        }
        dict_to_dump.update(self)
        np.savez_compressed(npz_path, **dict_to_dump)

    def dump_by_pickle(self, pkl_path: str, overwrite: bool = True) -> None:
        """Dump keys and items to a pickle file. It's a secondary dump
        method, used when a MultiHumanData instance is too large to be
        dumped by self.dump().

        Args:
            pkl_path (str): Path to a dumped pickle file.
            overwrite (bool, optional): Whether to overwrite if there is
                already a file. Defaults to True.

        Raises:
            ValueError: pkl_path does not end with '.pkl'.
            FileExistsError: When overwrite is False and the file exists.
""" if not check_path_suffix(pkl_path, ['.pkl']): raise ValueError('Not an pkl file.') if not overwrite: if check_path_existence(pkl_path, 'file') == Existence.FileExist: raise FileExistsError dict_to_dump = { '__key_strict__': self.__key_strict__, '__data_len__': self.__data_len__, '__instance_num__': self.__instance_num__, '__keypoints_compressed__': self.__keypoints_compressed__, } dict_to_dump.update(self) with open(pkl_path, 'wb') as f_writeb: pickle.dump(dict_to_dump, f_writeb, protocol=pickle.HIGHEST_PROTOCOL) def load_by_pickle(self, pkl_path: str) -> None: """Load data from pkl_path and update them to self. When a HumanData Instance was dumped by self.dump_by_pickle(), use this to load. Args: npz_path (str): Path to a dumped npz file. """ with open(pkl_path, 'rb') as f_readb: tmp_data_dict = pickle.load(f_readb) for key, value in list(tmp_data_dict.items()): if value is None: tmp_data_dict.pop(key) elif key == '__key_strict__' or \ key == '__data_len__' or\ key == '__instance_num__' or\ key == '__keypoints_compressed__': self.__setattr__(key, value) # pop the attributes to keep dict clean tmp_data_dict.pop(key) elif key == 'bbox_xywh' and value.shape[1] == 4: value = np.hstack([value, np.ones([value.shape[0], 1])]) tmp_data_dict[key] = value else: tmp_data_dict[key] = value self.update(tmp_data_dict) self.__set_default_values__() @property def instance_num(self) -> int: """Get the human instance num of this MultiHumanData instance. In MuliHumanData, an image may have multiple corresponding human instances. Returns: int: Number of human instance related to this instance. """ return self.__instance_num__ @instance_num.setter def instance_num(self, value: int): """Set the human instance num of this MultiHumanData instance. Args: value (int): Number of human instance related to this instance. """ self.__instance_num__ = value def get_slice(self, arg_0: int, arg_1: Union[int, Any] = None, step: int = 1) -> _HumanData: """Slice all sliceable values along major_dim dimension. Args: arg_0 (int): When arg_1 is None, arg_0 is stop and start=0. When arg_1 is not None, arg_0 is start. arg_1 (Union[int, Any], optional): None or where to stop. Defaults to None. step (int, optional): Length of step. Defaults to 1. Returns: MultiHumanData: A new MultiHumanData instance with sliced values. 
""" ret_human_data = \ MultiHumanData.new(key_strict=self.get_key_strict()) if arg_1 is None: start = 0 stop = arg_0 else: start = arg_0 stop = arg_1 slice_index = slice(start, stop, step) dim_dict = self.__get_slice_dim__() # frame_range = self.get_raw_value('optional')['frame_range'] for key, dim in dim_dict.items(): # primary index if key == 'optional': frame_range = None else: frame_range = self.get_raw_value('optional')['frame_range'] # keys not expected be sliced if dim is None: ret_human_data[key] = self[key] elif isinstance(dim, dict): value_dict = self.get_raw_value(key) sliced_dict = {} for sub_key in value_dict.keys(): sub_value = value_dict[sub_key] if dim[sub_key] is None: sliced_dict[sub_key] = sub_value else: sub_dim = dim[sub_key] sliced_sub_value = \ MultiHumanData.__get_sliced_result__( sub_value, sub_dim, slice_index, frame_range) sliced_dict[sub_key] = sliced_sub_value ret_human_data[key] = sliced_dict else: value = self[key] sliced_value = \ MultiHumanData.__get_sliced_result__( value, dim, slice_index, frame_range) ret_human_data[key] = sliced_value # check keypoints compressed if self.check_keypoints_compressed(): ret_human_data.compress_keypoints_by_mask() return ret_human_data def __get_slice_dim__(self) -> dict: """For each key in this HumanData, get the dimension for slicing. 0 for default, if no other value specified. Returns: dict: Keys are self.keys(). Values indicate where to slice. None for not expected to be sliced or failed. """ supported_keys = self.__class__.SUPPORTED_KEYS ret_dict = {} for key in self.keys(): # keys not expected be sliced if key in supported_keys and \ 'dim' in supported_keys[key] and \ supported_keys[key]['dim'] is None: ret_dict[key] = None else: value = self[key] if isinstance(value, dict) and len(value) > 0: ret_dict[key] = {} for sub_key in value.keys(): try: sub_value_len = len(value[sub_key]) if sub_value_len != self.instance_num and \ sub_value_len != self.data_len: ret_dict[key][sub_key] = None elif 'dim' in value: ret_dict[key][sub_key] = value['dim'] else: ret_dict[key][sub_key] = 0 except TypeError: ret_dict[key][sub_key] = None continue # instance cannot be sliced without len method try: value_len = len(value) except TypeError: ret_dict[key] = None continue # slice on dim 0 by default slice_dim = 0 if key in supported_keys and \ 'dim' in supported_keys[key]: slice_dim = \ supported_keys[key]['dim'] data_len = value_len if slice_dim == 0 \ else value.shape[slice_dim] # dim not for slice if data_len != self.__instance_num__: ret_dict[key] = None continue else: ret_dict[key] = slice_dim return ret_dict # TODO: to support cache def __check_value_len__(self, key: Any, val: Any) -> bool: """Check whether the temporal length of val matches other values. Args: key (Any): Key in MultiHumanData. val (Any): Value to the key. Returns: bool: If temporal dim is defined and temporal length doesn't match, return False. Else return True. 
""" ret_bool = True supported_keys = self.__class__.SUPPORTED_KEYS # MultiHumanData instance_num = 0 if key == 'optional' and \ 'frame_range' in val: for frame_range in val['frame_range']: instance_num += (frame_range[-1] - frame_range[0]) if self.instance_num == -1: # init instance_num for multi_human_data self.instance_num = instance_num elif self.instance_num != instance_num: ret_bool = False data_len = len(val['frame_range']) if self.data_len == -1: # init data_len self.data_len = data_len elif self.data_len == self.instance_num: # update data_len self.data_len = data_len elif self.data_len != self.instance_num: ret_bool = False # check definition elif key in supported_keys: # check data length if 'dim' in supported_keys[key] and \ supported_keys[key]['dim'] is not None: val_slice_dim = supported_keys[key]['dim'] if supported_keys[key]['type'] == dict: slice_key = supported_keys[key]['slice_key'] val_data_len = val[slice_key].shape[val_slice_dim] else: val_data_len = val.shape[val_slice_dim] if self.instance_num < 0: # Init instance_num for HumanData, # which is equal to data_len. self.instance_num = val_data_len else: # check if val_data_len matches recorded instance_num if self.instance_num != val_data_len: ret_bool = False if self.data_len < 0: # init data_len for HumanData, it's equal to # instance_num. # If it's MultiHumanData needs to be updated self.data_len = val_data_len if not ret_bool: err_msg = 'Data length check Failed:\n' err_msg += f'key={str(key)}\n' if self.data_len != self.instance_num: err_msg += f'val\'s instance_num={self.data_len}\n' err_msg += f'expected instance_num={self.instance_num}\n' print_log(msg=err_msg, logger=self.__class__.logger, level=logging.ERROR) return ret_bool def __set_default_values__(self) -> None: """For older versions of HumanData, call this method to apply missing values (also attributes). Note: 1. Older HumanData doesn't define `data_len`. 2. In the newer HumanData, `data_len` equals the `instances_num`. 3. In MultiHumanData, `instance_num` equals instances num, and `data_len` equals frames num. 
""" supported_keys = self.__class__.SUPPORTED_KEYS if self.instance_num == -1: # the loaded file is not multi_human_data for key in supported_keys: if key in self and \ 'dim' in supported_keys[key] and\ supported_keys[key]['dim'] is not None: if 'slice_key' in supported_keys[key] and\ supported_keys[key]['type'] == dict: sub_key = supported_keys[key]['slice_key'] slice_dim = supported_keys[key]['dim'] self.instance_num = self[key][sub_key].shape[slice_dim] else: slice_dim = supported_keys[key]['dim'] self.instance_num = self[key].shape[slice_dim] # convert HumanData to MultiHumanData self.data_len = self.instance_num optional = {} optional['frame_range'] = \ [[i, i + 1] for i in range(self.data_len)] self['optional'] = optional break for key in list(self.keys()): convention_key = f'{key}_convention' if key.startswith('keypoints') and \ not key.endswith('_mask') and \ not key.endswith('_convention') and \ convention_key not in self: self[convention_key] = 'human_data' @classmethod def __get_sliced_result__( cls, input_data: Union[np.ndarray, list, tuple], slice_dim: int, slice_range: slice, frame_index: list = None) -> Union[np.ndarray, list, tuple]: if frame_index is not None: slice_data = [] for frame_range in frame_index[slice_range]: slice_index = slice(frame_range[0], frame_range[-1], 1) slice_result = \ HumanData.__get_sliced_result__( input_data, slice_dim, slice_index) for element in slice_result: slice_data.append(element) if isinstance(input_data, np.ndarray): slice_data = np.array(slice_data) else: slice_data = type(input_data)(slice_data) else: # primary index slice_data = \ HumanData.__get_sliced_result__( input_data, slice_dim, slice_range) return slice_data