Spaces:
Running
on
L40S
Running
on
L40S
from typing import List, Optional

import numpy as np

from detrsmpl.utils.path_utils import (
    Existence,
    check_path_existence,
    check_path_suffix,
)

from .human_data import HumanData
class HumanDataCacheReader():
    """Read HumanData items from a cache npz file slice by slice.

    The npz file must contain the entries ``slice_size``, ``data_len``,
    ``keypoints_info`` and ``non_sliced_data``, plus one entry per slice
    keyed by the slice index as a string (``'0'``, ``'1'``, ...).

    Note:
        The file is opened with ``allow_pickle=True``, which unpickles
        arbitrary objects. Only load cache files from trusted sources.
    """

    def __init__(self, npz_path: str):
        """
        Args:
            npz_path (str):
                Path to the cache npz file.
        """
        self.npz_path = npz_path
        # Set these first so __del__ is safe even if loading below fails.
        self.non_sliced_data = None
        self.npz_file = None
        # Read the small metadata entries through a short-lived handle that
        # is closed right away. The long-lived handle is opened lazily in
        # get_item(), presumably so that each process (e.g. a DataLoader
        # worker) opens its own handle. TODO confirm.
        with np.load(npz_path, allow_pickle=True) as npz_file:
            self.slice_size = npz_file['slice_size'].item()
            self.data_len = npz_file['data_len'].item()
            self.keypoints_info = npz_file['keypoints_info'].item()

    def __del__(self):
        """Close the lazily opened npz handle, if one was opened."""
        # getattr: the attribute may be missing on a partially constructed
        # instance (e.g. one created via __new__).
        npz_file = getattr(self, 'npz_file', None)
        if npz_file is not None:
            npz_file.close()

    def get_item(self, index, required_keys: Optional[List[str]] = None):
        """Build a HumanData from the cache entry of the slice holding index.

        Args:
            index (int):
                Index of a data item. The slice looked up is
                ``index // slice_size``. Negative indices are not supported.
            required_keys (List[str], optional):
                Keys of non-sliced data to attach to the result. When both
                the cached value and the slice value are dicts they are
                merged; otherwise the non-sliced value replaces the slice
                value. Defaults to None (no extra keys).

        Returns:
            HumanData:
                Built from the slice entry, ``keypoints_info`` and the
                requested non-sliced data.

        Raises:
            KeyError:
                The slice entry or a requested non-sliced key is missing.
        """
        if required_keys is None:
            required_keys = []
        if self.npz_file is None:
            self.npz_file = np.load(self.npz_path, allow_pickle=True)
        # Exact integer arithmetic; avoids a round trip through float.
        cache_key = str(int(index) // self.slice_size)
        base_data = self.npz_file[cache_key].item()
        # NOTE(review): keypoints_info values and the cached non-sliced
        # values below are inserted by reference. If HumanData mutates them
        # in place, later items would observe the change. Confirm.
        base_data.update(self.keypoints_info)
        for key in required_keys:
            non_sliced_value = self.get_non_sliced_data(key)
            if isinstance(non_sliced_value, dict) and \
                    isinstance(base_data.get(key), dict):
                base_data[key].update(non_sliced_value)
            else:
                base_data[key] = non_sliced_value
        ret_human_data = HumanData.new(source_dict=base_data)
        # data in cache is compressed
        ret_human_data.__keypoints_compressed__ = True
        # set missing values and attributes by default method
        ret_human_data.__set_default_values__()
        return ret_human_data

    def get_non_sliced_data(self, key: str):
        """Return ``non_sliced_data[key]``, loading the dict on first use.

        Args:
            key (str):
                Key in the ``non_sliced_data`` entry of the npz file.

        Returns:
            The stored value. The loaded dict is cached on the instance.

        Raises:
            KeyError:
                ``key`` is not present in ``non_sliced_data``.
        """
        if self.non_sliced_data is None:
            if self.npz_file is None:
                # No persistent handle yet: use a short-lived one and close
                # it, rather than leaking an open file.
                with np.load(self.npz_path, allow_pickle=True) as npz_file:
                    self.non_sliced_data = \
                        npz_file['non_sliced_data'].item()
            else:
                self.non_sliced_data = \
                    self.npz_file['non_sliced_data'].item()
        return self.non_sliced_data[key]
class HumanDataCacheWriter():
    """Collect HumanData cache contents in memory and dump them to an npz.

    Args:
        slice_size (int):
            Number of data items per slice. Readers map an item index to
            the slice entry ``str(index // slice_size)``.
        data_len (int):
            Total number of data items.
        keypoints_info (dict):
            Entries written under ``keypoints_info``.
        non_sliced_data (dict):
            Entries written under ``non_sliced_data``.
        key_strict (bool, optional):
            Flag written to the npz as-is. Its meaning is defined by the
            consumer (presumably HumanData strict key checking; confirm).
            Defaults to True.
    """

    def __init__(self,
                 slice_size: int,
                 data_len: int,
                 keypoints_info: dict,
                 non_sliced_data: dict,
                 key_strict: bool = True):
        self.slice_size = slice_size
        self.data_len = data_len
        self.keypoints_info = keypoints_info
        self.non_sliced_data = non_sliced_data
        # slice key -> slice payload, filled via update_sliced_dict().
        self.sliced_data = {}
        self.key_strict = key_strict

    def update_sliced_dict(self, sliced_dict):
        """Merge per-slice entries into the pending dump.

        Args:
            sliced_dict (dict):
                Mapping of slice key to slice payload. Keys already present
                are overwritten.
        """
        self.sliced_data.update(sliced_dict)

    def _build_dump_dict(self) -> dict:
        """Assemble the name -> value mapping written to the npz file.

        Returns:
            dict:
                Metadata entries plus one entry per slice, keyed by
                ``str(slice_key)``.

        Raises:
            ValueError:
                A slice key, converted to str, collides with a metadata
                entry or with another slice key.
        """
        dict_to_dump = {
            'slice_size': self.slice_size,
            'data_len': self.data_len,
            'keypoints_info': self.keypoints_info,
            'non_sliced_data': self.non_sliced_data,
            'key_strict': self.key_strict,
        }
        for key, value in self.sliced_data.items():
            # npz member names must be str. The reader looks slices up by
            # str(slice_index), so int keys are stored the same way.
            name = str(key)
            if name in dict_to_dump:
                # Previously this silently overwrote metadata such as
                # 'data_len', producing a corrupt cache.
                raise ValueError(
                    f'Sliced data key {key!r} collides with an existing '
                    'entry in the cache file.')
            dict_to_dump[name] = value
        return dict_to_dump

    def dump(self, npz_path: str, overwrite: bool = True):
        """Dump keys and items to an npz file.

        Args:
            npz_path (str):
                Path to a dumped npz file.
            overwrite (bool, optional):
                Whether to overwrite if there is already a file.
                Defaults to True.

        Raises:
            ValueError:
                npz_path does not end with '.npz', or a sliced-data key
                collides with another entry.
            FileExistsError:
                When overwrite is False and file exists.
        """
        if not check_path_suffix(npz_path, ['.npz']):
            raise ValueError('Not an npz file.')
        if not overwrite:
            if check_path_existence(npz_path, 'file') == Existence.FileExist:
                raise FileExistsError(f'{npz_path} already exists.')
        np.savez_compressed(npz_path, **self._build_dump_dict())