import logging
import pickle
from enum import Enum
from typing import Any, TypeVar, Union

import numpy as np
from mmcv.utils import print_log

from detrsmpl.data.data_structures.human_data import HumanData
from detrsmpl.utils.path_utils import (
    Existence,
    check_path_existence,
    check_path_suffix,
)

# In T = TypeVar('T'), T can be anything.
# See the definition of typing.TypeVar for details.
_HumanData = TypeVar('_HumanData')

_MultiHumanData_SUPPORTED_KEYS = HumanData.SUPPORTED_KEYS.copy()
_MultiHumanData_SUPPORTED_KEYS.update(
    {'optional': {
        'type': dict,
        'slice_key': 'frame_range',
        'dim': 0
    }})
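
# The extra 'optional' key is the primary index of MultiHumanData: it maps
# every frame to a half-open range of human-instance rows. A minimal sketch
# of the expected layout (the values below are hypothetical):
#
#     human_data['optional'] = {
#         # frame i owns instance rows frame_range[i][0]:frame_range[i][-1]
#         'frame_range': [[0, 2], [2, 3], [3, 5]],  # 3 frames, 5 instances
#     }
#
# With such a layout, data_len is 3 (frames) and instance_num is 5.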
class _KeyCheck(Enum):
    PASS = 0
    WARN = 1
    ERROR = 2
class MultiHumanData(HumanData):
    SUPPORTED_KEYS = _MultiHumanData_SUPPORTED_KEYS

    def __new__(cls: _HumanData, *args: Any, **kwargs: Any) -> _HumanData:
        """Create a new instance of MultiHumanData.

        Args:
            cls (MultiHumanData): MultiHumanData class.

        Returns:
            MultiHumanData: An instance of MultiHumanData.
        """
        ret_human_data = super().__new__(cls, *args, **kwargs)
        setattr(ret_human_data, '__data_len__', -1)
        setattr(ret_human_data, '__instance_num__', -1)
        setattr(ret_human_data, '__key_strict__', False)
        setattr(ret_human_data, '__keypoints_compressed__', False)
        return ret_human_data
    def load(self, npz_path: str):
        """Load data from npz_path and update them to self.

        Args:
            npz_path (str):
                Path to a dumped npz file.
        """
        supported_keys = self.__class__.SUPPORTED_KEYS
        with np.load(npz_path, allow_pickle=True) as npz_file:
            tmp_data_dict = dict(npz_file)
            for key, value in list(tmp_data_dict.items()):
                if isinstance(value, np.ndarray) and \
                        len(value.shape) == 0:
                    # value is not an ndarray before dump
                    value = value.item()
                elif key in supported_keys and \
                        type(value) != supported_keys[key]['type']:
                    value = supported_keys[key]['type'](value)
                if value is None:
                    tmp_data_dict.pop(key)
                elif key in ('__key_strict__', '__data_len__',
                             '__instance_num__',
                             '__keypoints_compressed__'):
                    self.__setattr__(key, value)
                    # pop the attributes to keep dict clean
                    tmp_data_dict.pop(key)
                elif key == 'bbox_xywh' and value.shape[1] == 4:
                    value = np.hstack([value, np.ones([value.shape[0], 1])])
                    tmp_data_dict[key] = value
                else:
                    tmp_data_dict[key] = value
            self.update(tmp_data_dict)
        self.__set_default_values__()
    def dump(self, npz_path: str, overwrite: bool = True):
        """Dump keys and items to an npz file.

        Args:
            npz_path (str):
                Path to a dumped npz file.
            overwrite (bool, optional):
                Whether to overwrite if there is already a file.
                Defaults to True.

        Raises:
            ValueError:
                npz_path does not end with '.npz'.
            FileExistsError:
                When overwrite is False and the file exists.
        """
        if not check_path_suffix(npz_path, ['.npz']):
            raise ValueError('Not an npz file.')
        if not overwrite:
            if check_path_existence(npz_path, 'file') == Existence.FileExist:
                raise FileExistsError
        dict_to_dump = {
            '__key_strict__': self.__key_strict__,
            '__data_len__': self.__data_len__,
            '__instance_num__': self.__instance_num__,
            '__keypoints_compressed__': self.__keypoints_compressed__,
        }
        dict_to_dump.update(self)
        np.savez_compressed(npz_path, **dict_to_dump)
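
    # Round-trip sketch for the npz format, assuming a writable path
    # (the path below is hypothetical):
    #
    #     multi_human_data.dump('/tmp/annotation.npz')
    #     loaded = MultiHumanData()
    #     loaded.load('/tmp/annotation.npz')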
    def dump_by_pickle(self, pkl_path: str, overwrite: bool = True) -> None:
        """Dump keys and items to a pickle file. It's a secondary dump
        method used when a HumanData instance is too large to be dumped by
        self.dump().

        Args:
            pkl_path (str):
                Path to a dumped pickle file.
            overwrite (bool, optional):
                Whether to overwrite if there is already a file.
                Defaults to True.

        Raises:
            ValueError:
                pkl_path does not end with '.pkl'.
            FileExistsError:
                When overwrite is False and the file exists.
        """
        if not check_path_suffix(pkl_path, ['.pkl']):
            raise ValueError('Not a pkl file.')
        if not overwrite:
            if check_path_existence(pkl_path, 'file') == Existence.FileExist:
                raise FileExistsError
        dict_to_dump = {
            '__key_strict__': self.__key_strict__,
            '__data_len__': self.__data_len__,
            '__instance_num__': self.__instance_num__,
            '__keypoints_compressed__': self.__keypoints_compressed__,
        }
        dict_to_dump.update(self)
        with open(pkl_path, 'wb') as f_writeb:
            pickle.dump(dict_to_dump,
                        f_writeb,
                        protocol=pickle.HIGHEST_PROTOCOL)
    def load_by_pickle(self, pkl_path: str) -> None:
        """Load data from pkl_path and update them to self.

        Use this method to load a HumanData instance that was dumped by
        self.dump_by_pickle().

        Args:
            pkl_path (str):
                Path to a dumped pickle file.
        """
        with open(pkl_path, 'rb') as f_readb:
            tmp_data_dict = pickle.load(f_readb)
            for key, value in list(tmp_data_dict.items()):
                if value is None:
                    tmp_data_dict.pop(key)
                elif key in ('__key_strict__', '__data_len__',
                             '__instance_num__',
                             '__keypoints_compressed__'):
                    self.__setattr__(key, value)
                    # pop the attributes to keep dict clean
                    tmp_data_dict.pop(key)
                elif key == 'bbox_xywh' and value.shape[1] == 4:
                    value = np.hstack([value, np.ones([value.shape[0], 1])])
                    tmp_data_dict[key] = value
                else:
                    tmp_data_dict[key] = value
            self.update(tmp_data_dict)
        self.__set_default_values__()
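
    # Pickle fallback sketch for instances too large for np.savez_compressed
    # (the path below is hypothetical):
    #
    #     multi_human_data.dump_by_pickle('/tmp/annotation.pkl')
    #     loaded = MultiHumanData()
    #     loaded.load_by_pickle('/tmp/annotation.pkl')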
    @property
    def instance_num(self) -> int:
        """Get the number of human instances in this MultiHumanData
        instance. In MultiHumanData, an image may correspond to multiple
        human instances.

        Returns:
            int:
                Number of human instances related to this instance.
        """
        return self.__instance_num__

    @instance_num.setter
    def instance_num(self, value: int):
        """Set the number of human instances in this MultiHumanData
        instance.

        Args:
            value (int):
                Number of human instances related to this instance.
        """
        self.__instance_num__ = value
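
    # With the hypothetical 3-frame / 5-instance layout sketched above:
    #
    #     multi_human_data.data_len      # -> 3, number of frames
    #     multi_human_data.instance_num  # -> 5, number of human instances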
    def get_slice(self,
                  arg_0: int,
                  arg_1: Union[int, Any] = None,
                  step: int = 1) -> _HumanData:
        """Slice all sliceable values along the major_dim dimension.

        Args:
            arg_0 (int):
                When arg_1 is None, arg_0 is stop and start=0.
                When arg_1 is not None, arg_0 is start.
            arg_1 (Union[int, Any], optional):
                None or where to stop.
                Defaults to None.
            step (int, optional):
                Length of step. Defaults to 1.

        Returns:
            MultiHumanData:
                A new MultiHumanData instance with sliced values.
        """
        ret_human_data = \
            MultiHumanData.new(key_strict=self.get_key_strict())
        if arg_1 is None:
            start = 0
            stop = arg_0
        else:
            start = arg_0
            stop = arg_1
        slice_index = slice(start, stop, step)
        dim_dict = self.__get_slice_dim__()
        for key, dim in dim_dict.items():
            # 'optional' is the primary index and is sliced directly;
            # every other key is sliced through its frame_range
            if key == 'optional':
                frame_range = None
            else:
                frame_range = self.get_raw_value('optional')['frame_range']
            # keys not expected to be sliced
            if dim is None:
                ret_human_data[key] = self[key]
            elif isinstance(dim, dict):
                value_dict = self.get_raw_value(key)
                sliced_dict = {}
                for sub_key in value_dict.keys():
                    sub_value = value_dict[sub_key]
                    if dim[sub_key] is None:
                        sliced_dict[sub_key] = sub_value
                    else:
                        sub_dim = dim[sub_key]
                        sliced_sub_value = \
                            MultiHumanData.__get_sliced_result__(
                                sub_value, sub_dim, slice_index, frame_range)
                        sliced_dict[sub_key] = sliced_sub_value
                ret_human_data[key] = sliced_dict
            else:
                value = self[key]
                sliced_value = \
                    MultiHumanData.__get_sliced_result__(
                        value, dim, slice_index, frame_range)
                ret_human_data[key] = sliced_value
        # check keypoints compressed
        if self.check_keypoints_compressed():
            ret_human_data.compress_keypoints_by_mask()
        return ret_human_data
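
    # Slicing sketch: with the hypothetical frame_range
    # [[0, 2], [2, 3], [3, 5]], get_slice(2) keeps frames 0 and 1, i.e.
    # instance rows 0:2 and 2:3, so instance-aligned arrays keep 3 of the
    # original 5 rows:
    #
    #     first_two_frames = multi_human_data.get_slice(2)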
    def __get_slice_dim__(self) -> dict:
        """For each key in this MultiHumanData, get the dimension for
        slicing. 0 by default, if no other value is specified.

        Returns:
            dict:
                Keys are self.keys().
                Values indicate where to slice.
                None means the key is not expected to be sliced
                or cannot be sliced.
        """
        supported_keys = self.__class__.SUPPORTED_KEYS
        ret_dict = {}
        for key in self.keys():
            # keys not expected to be sliced
            if key in supported_keys and \
                    'dim' in supported_keys[key] and \
                    supported_keys[key]['dim'] is None:
                ret_dict[key] = None
            else:
                value = self[key]
                if isinstance(value, dict) and len(value) > 0:
                    ret_dict[key] = {}
                    for sub_key in value.keys():
                        try:
                            sub_value_len = len(value[sub_key])
                            if sub_value_len != self.instance_num and \
                                    sub_value_len != self.data_len:
                                ret_dict[key][sub_key] = None
                            elif 'dim' in value:
                                ret_dict[key][sub_key] = value['dim']
                            else:
                                ret_dict[key][sub_key] = 0
                        except TypeError:
                            ret_dict[key][sub_key] = None
                    continue
                # instance cannot be sliced without a len method
                try:
                    value_len = len(value)
                except TypeError:
                    ret_dict[key] = None
                    continue
                # slice on dim 0 by default
                slice_dim = 0
                if key in supported_keys and \
                        'dim' in supported_keys[key]:
                    slice_dim = \
                        supported_keys[key]['dim']
                data_len = value_len if slice_dim == 0 \
                    else value.shape[slice_dim]
                # dim not for slice
                if data_len != self.__instance_num__:
                    ret_dict[key] = None
                    continue
                else:
                    ret_dict[key] = slice_dim
        return ret_dict
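
    # Sketch of a possible return value (keys and values are hypothetical):
    # 0 means slice along dim 0, None means copy as-is.
    #
    #     {
    #         'optional': {'frame_range': 0},
    #         'keypoints2d': 0,
    #         'config': None,
    #     }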
    # TODO: to support cache
    def __check_value_len__(self, key: Any, val: Any) -> bool:
        """Check whether the temporal length of val matches other values.

        Args:
            key (Any):
                Key in MultiHumanData.
            val (Any):
                Value to the key.

        Returns:
            bool:
                If the temporal dim is defined and the temporal length
                doesn't match, return False.
                Else return True.
        """
        ret_bool = True
        supported_keys = self.__class__.SUPPORTED_KEYS
        # MultiHumanData
        instance_num = 0
        if key == 'optional' and \
                'frame_range' in val:
            for frame_range in val['frame_range']:
                instance_num += (frame_range[-1] - frame_range[0])
            if self.instance_num == -1:
                # init instance_num for multi_human_data
                self.instance_num = instance_num
            elif self.instance_num != instance_num:
                ret_bool = False
            data_len = len(val['frame_range'])
            if self.data_len == -1:
                # init data_len
                self.data_len = data_len
            elif self.data_len == self.instance_num:
                # update data_len
                self.data_len = data_len
            elif self.data_len != self.instance_num:
                ret_bool = False
        # check definition
        elif key in supported_keys:
            # check data length
            if 'dim' in supported_keys[key] and \
                    supported_keys[key]['dim'] is not None:
                val_slice_dim = supported_keys[key]['dim']
                if supported_keys[key]['type'] == dict:
                    slice_key = supported_keys[key]['slice_key']
                    val_data_len = val[slice_key].shape[val_slice_dim]
                else:
                    val_data_len = val.shape[val_slice_dim]
                if self.instance_num < 0:
                    # Init instance_num for HumanData,
                    # which is equal to data_len.
                    self.instance_num = val_data_len
                else:
                    # check if val_data_len matches recorded instance_num
                    if self.instance_num != val_data_len:
                        ret_bool = False
                if self.data_len < 0:
                    # Init data_len for HumanData; it's equal to
                    # instance_num. If this is a MultiHumanData, it will
                    # be updated when 'optional' is set.
                    self.data_len = val_data_len
        if not ret_bool:
            err_msg = 'Data length check failed:\n'
            err_msg += f'key={str(key)}\n'
            if self.data_len != self.instance_num:
                err_msg += f'val\'s instance_num={self.data_len}\n'
                err_msg += f'expected instance_num={self.instance_num}\n'
            print_log(msg=err_msg,
                      logger=self.__class__.logger,
                      level=logging.ERROR)
        return ret_bool
    def __set_default_values__(self) -> None:
        """For older versions of HumanData, call this method to fill in
        missing values (and attributes).

        Note:
            1. Older HumanData doesn't define `data_len`.
            2. In newer HumanData, `data_len` equals `instance_num`.
            3. In MultiHumanData, `instance_num` is the number of human
               instances and `data_len` is the number of frames.
        """
        supported_keys = self.__class__.SUPPORTED_KEYS
        if self.instance_num == -1:
            # the loaded file is not multi_human_data
            for key in supported_keys:
                if key in self and \
                        'dim' in supported_keys[key] and \
                        supported_keys[key]['dim'] is not None:
                    if 'slice_key' in supported_keys[key] and \
                            supported_keys[key]['type'] == dict:
                        sub_key = supported_keys[key]['slice_key']
                        slice_dim = supported_keys[key]['dim']
                        self.instance_num = \
                            self[key][sub_key].shape[slice_dim]
                    else:
                        slice_dim = supported_keys[key]['dim']
                        self.instance_num = self[key].shape[slice_dim]
                    # convert HumanData to MultiHumanData
                    self.data_len = self.instance_num
                    optional = {}
                    optional['frame_range'] = \
                        [[i, i + 1] for i in range(self.data_len)]
                    self['optional'] = optional
                    break
        for key in list(self.keys()):
            convention_key = f'{key}_convention'
            if key.startswith('keypoints') and \
                    not key.endswith('_mask') and \
                    not key.endswith('_convention') and \
                    convention_key not in self:
                self[convention_key] = 'human_data'
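
    # Conversion sketch: after loading a plain single-person HumanData file
    # with, say, 100 instances, instance_num is still -1, so this method
    # builds optional['frame_range'] = [[0, 1], [1, 2], ..., [99, 100]],
    # mapping each frame to exactly one instance.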
    @classmethod
    def __get_sliced_result__(
            cls,
            input_data: Union[np.ndarray, list, tuple],
            slice_dim: int,
            slice_range: slice,
            frame_index: list = None) -> Union[np.ndarray, list, tuple]:
        if frame_index is not None:
            # secondary index: gather the instance rows belonging to each
            # selected frame, then concatenate them
            slice_data = []
            for frame_range in frame_index[slice_range]:
                slice_index = slice(frame_range[0], frame_range[-1], 1)
                slice_result = \
                    HumanData.__get_sliced_result__(
                        input_data,
                        slice_dim,
                        slice_index)
                for element in slice_result:
                    slice_data.append(element)
            if isinstance(input_data, np.ndarray):
                slice_data = np.array(slice_data)
            else:
                slice_data = type(input_data)(slice_data)
        else:
            # primary index
            slice_data = \
                HumanData.__get_sliced_result__(
                    input_data,
                    slice_dim,
                    slice_range)
        return slice_data
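

# A minimal smoke-test sketch, assuming the detrsmpl HumanData API behaves
# as used above; the shapes, key names and numbers below are hypothetical.
if __name__ == '__main__':
    multi_human_data = MultiHumanData()
    # 190 keypoints matches the 'human_data' convention assumed by default
    multi_human_data['keypoints2d_mask'] = np.ones((190, ))
    multi_human_data['keypoints2d'] = np.zeros((5, 190, 3))  # 5 instances
    multi_human_data['optional'] = {
        # 3 frames indexing the 5 instance rows above
        'frame_range': [[0, 2], [2, 3], [3, 5]],
    }
    assert multi_human_data.instance_num == 5
    assert multi_human_data.data_len == 3
    # keep frames 0 and 1, i.e. 3 of the 5 instance rows
    sliced = multi_human_data.get_slice(2)
    print(sliced.instance_num)  # expected: 3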