AiOS / detrsmpl /data /data_structures /human_data_cache.py
ttxskk
update
d7e58f0
raw
history blame
3.76 kB
from typing import List, Optional

import numpy as np

from detrsmpl.utils.path_utils import (
    Existence,
    check_path_existence,
    check_path_suffix,
)

from .human_data import HumanData
class HumanDataCacheReader():
    """Random-access reader for a HumanData cache stored in one npz file.

    The npz file holds the entries written by ``HumanDataCacheWriter.dump``:
    'slice_size', 'data_len', 'keypoints_info', 'non_sliced_data', plus one
    entry per slice keyed by the slice number as a string.

    NOTE: every entry is loaded with ``allow_pickle=True``; only open cache
    files from trusted sources, since unpickling can execute arbitrary code.
    """

    def __init__(self, npz_path: str):
        """Read the cache metadata from ``npz_path``.

        Args:
            npz_path (str): Path to the npz cache file.
        """
        self.npz_path = npz_path
        # Open only transiently to read the metadata entries and close the
        # handle right away; the handle used for slice access is opened
        # lazily in get_item().
        with np.load(npz_path, allow_pickle=True) as npz_file:
            self.slice_size = npz_file['slice_size'].item()
            self.data_len = npz_file['data_len'].item()
            self.keypoints_info = npz_file['keypoints_info'].item()
        # Filled on first call to get_non_sliced_data().
        self.non_sliced_data = None
        # Persistent handle, opened on first call to get_item().
        self.npz_file = None

    def __del__(self):
        """Close the persistent npz handle, if one was opened."""
        # getattr guards against __init__ having raised before npz_file was
        # assigned, in which case the attribute does not exist.
        npz_file = getattr(self, 'npz_file', None)
        if npz_file is not None:
            npz_file.close()

    def get_item(self, index, required_keys: Optional[List[str]] = None):
        """Build a HumanData for the slice that contains ``index``.

        Args:
            index (int): Sample index, assumed non-negative.
            required_keys (List[str], optional): Keys to merge in from the
                non-sliced data. Defaults to None, meaning no extra keys.

        Returns:
            HumanData: Built from the slice entry with ``keypoints_info`` and
            the requested non-sliced values merged in, flagged as
            keypoints-compressed and with default values filled in.
        """
        if required_keys is None:
            required_keys = []
        if self.npz_file is None:
            self.npz_file = np.load(self.npz_path, allow_pickle=True)
        # Exact integer floor division; int() keeps the key free of a '.0'
        # suffix if index arrives as a float.
        cache_key = str(int(index // self.slice_size))
        base_data = self.npz_file[cache_key].item()
        # NOTE(review): values from self.keypoints_info and the cached
        # non-sliced data are inserted by reference, so in-place mutation of
        # the returned HumanData could leak into later items — confirm that
        # HumanData copies or never mutates them.
        base_data.update(self.keypoints_info)
        for key in required_keys:
            non_sliced_value = self.get_non_sliced_data(key)
            existing_value = base_data.get(key)
            if isinstance(non_sliced_value, dict) and \
                    isinstance(existing_value, dict):
                # Both sides are dicts: merge rather than replace.
                existing_value.update(non_sliced_value)
            else:
                base_data[key] = non_sliced_value
        ret_human_data = HumanData.new(source_dict=base_data)
        # data in cache is compressed
        ret_human_data.__keypoints_compressed__ = True
        # set missing values and attributes by default method
        ret_human_data.__set_default_values__()
        return ret_human_data

    def get_non_sliced_data(self, key: str):
        """Return ``key`` from the non-sliced data, loading it on first use.

        Args:
            key (str): Key to look up in the non-sliced data dict.

        Returns:
            The stored value for ``key``.

        Raises:
            KeyError: If ``key`` is not present in the non-sliced data.
        """
        if self.non_sliced_data is None:
            if self.npz_file is None:
                # No persistent handle yet: open one just for this read and
                # close it immediately instead of leaking it.
                with np.load(self.npz_path, allow_pickle=True) as npz_file:
                    self.non_sliced_data = \
                        npz_file['non_sliced_data'].item()
            else:
                self.non_sliced_data = \
                    self.npz_file['non_sliced_data'].item()
        return self.non_sliced_data[key]
class HumanDataCacheWriter():
    """Collect HumanData cache contents and dump them into one npz file.

    The dumped file contains the metadata entries 'slice_size', 'data_len',
    'keypoints_info', 'non_sliced_data' and 'key_strict', plus every entry
    added via ``update_sliced_dict``, all in one flat namespace.
    """

    def __init__(self,
                 slice_size: int,
                 data_len: int,
                 keypoints_info: dict,
                 non_sliced_data: dict,
                 key_strict: bool = True):
        """Store the cache metadata; slice entries are added later.

        Args:
            slice_size (int): Number of samples per slice entry.
            data_len (int): Total number of samples.
            keypoints_info (dict): Data stored once in the cache.
            non_sliced_data (dict): Data stored once in the cache, not split
                into slices.
            key_strict (bool, optional): Stored and dumped alongside the
                metadata. Defaults to True.
        """
        self.slice_size = slice_size
        self.data_len = data_len
        self.keypoints_info = keypoints_info
        self.non_sliced_data = non_sliced_data
        self.sliced_data = {}
        self.key_strict = key_strict

    def update_sliced_dict(self, sliced_dict):
        """Merge ``sliced_dict`` into the slice entries to be dumped.

        Args:
            sliced_dict (dict): Mapping of slice key to slice data. Later
                calls overwrite earlier entries with the same key.
        """
        self.sliced_data.update(sliced_dict)

    def dump(self, npz_path: str, overwrite: bool = True):
        """Dump keys and items to an npz file.

        Args:
            npz_path (str):
                Path to a dumped npz file.
            overwrite (bool, optional):
                Whether to overwrite if there is already a file.
                Defaults to True.

        Raises:
            ValueError:
                A slice key collides with a reserved metadata key, or
                npz_path does not end with '.npz'.
            FileExistsError:
                When overwrite is False and file exists.
        """
        metadata = {
            'slice_size': self.slice_size,
            'data_len': self.data_len,
            'keypoints_info': self.keypoints_info,
            'non_sliced_data': self.non_sliced_data,
            'key_strict': self.key_strict,
        }
        # Slice entries share one flat namespace with the metadata; a
        # colliding slice key would silently overwrite a metadata entry and
        # the reader would then load the wrong value. Check before touching
        # the filesystem.
        collisions = metadata.keys() & self.sliced_data.keys()
        if collisions:
            raise ValueError(
                'sliced_data keys collide with reserved cache keys: '
                '{}'.format(sorted(collisions)))
        if not check_path_suffix(npz_path, ['.npz']):
            raise ValueError('Not an npz file.')
        if not overwrite and \
                check_path_existence(npz_path, 'file') == \
                Existence.FileExist:
            raise FileExistsError(
                '{} already exists and overwrite is False.'.format(npz_path))
        dict_to_dump = dict(metadata)
        dict_to_dump.update(self.sliced_data)
        np.savez_compressed(npz_path, **dict_to_dump)