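from typing import List, Optional

import numpy as np

from detrsmpl.utils.path_utils import (
    Existence,
    check_path_existence,
    check_path_suffix,
)
from .human_data import HumanData


class HumanDataCacheReader:
    """Read HumanData slices from a cache npz file.

    The npz file stores the metadata entries ``slice_size``, ``data_len``,
    ``keypoints_info`` and ``non_sliced_data``, plus one entry per slice
    keyed by the slice index as a string. Sample ``index`` lives in the
    slice whose key is the integer quotient of ``index`` and ``slice_size``.
    """
    def __init__(self, npz_path: str):
        self.npz_path = npz_path
        # Read the small metadata entries once and close the file again;
        # slices are loaded lazily in get_item().
        with np.load(npz_path, allow_pickle=True) as npz_file:
            self.slice_size = npz_file['slice_size'].item()
            self.data_len = npz_file['data_len'].item()
            self.keypoints_info = npz_file['keypoints_info'].item()
        self.non_sliced_data = None
        self.npz_file = None

    def __del__(self):
        # getattr() guards against a partially constructed object,
        # e.g. when np.load() raised inside __init__.
        npz_file = getattr(self, 'npz_file', None)
        if npz_file is not None:
            npz_file.close()

    def get_item(self,
                 index: int,
                 required_keys: Optional[List[str]] = None) -> HumanData:
        """Return the cached slice that contains sample ``index``.

        Args:
            index (int):
                Global sample index.
            required_keys (List[str], optional):
                Keys to fetch from the non-sliced data. A dict value is
                merged into an existing dict entry of the same key;
                any other value overwrites it. Defaults to None.

        Returns:
            HumanData:
                The whole slice holding ``index``, with keypoints
                marked as compressed.
        """
        if required_keys is None:
            required_keys = []
        if self.npz_file is None:
            # Keep the file open so later calls do not reopen it.
            self.npz_file = np.load(self.npz_path, allow_pickle=True)
        # Integer division avoids float rounding errors for large indices.
        cache_key = str(int(index) // self.slice_size)
        base_data = self.npz_file[cache_key].item()
        base_data.update(self.keypoints_info)
        for key in required_keys:
            non_sliced_value = self.get_non_sliced_data(key)
            if (isinstance(non_sliced_value, dict) and key in base_data
                    and isinstance(base_data[key], dict)):
                # Merge dict-valued entries instead of replacing them.
                base_data[key].update(non_sliced_value)
            else:
                base_data[key] = non_sliced_value
        ret_human_data = HumanData.new(source_dict=base_data)
        # data in cache is compressed
        ret_human_data.__keypoints_compressed__ = True
        # set missing values and attributes by default method
        ret_human_data.__set_default_values__()
        return ret_human_data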

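    def get_non_sliced_data(self, key: str):
        """Return ``key`` of the non-sliced data, loaded lazily."""
        if self.non_sliced_data is None:
            if self.npz_file is None:
                # Open the file only for this read and close it right
                # away, so no file handle is leaked.
                with np.load(self.npz_path, allow_pickle=True) as npz_file:
                    self.non_sliced_data = npz_file['non_sliced_data'].item()
            else:
                self.non_sliced_data = self.npz_file['non_sliced_data'].item()
        return self.non_sliced_data[key]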
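

# Illustrative sketch, not part of the original module: _locate_sample is a
# hypothetical helper. get_item() returns the whole slice that holds
# `index`, so a caller presumably still has to pick the sample inside that
# slice. Assuming samples are stored in order within each slice, divmod()
# yields both the npz key that get_item() reads (the integer quotient, as a
# string) and the row offset inside the returned slice.
def _locate_sample(index: int, slice_size: int) -> tuple:
    """Map a global sample index to (slice key, offset in slice)."""
    slice_index, offset = divmod(int(index), slice_size)
    return str(slice_index), offset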


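class HumanDataCacheWriter:
    """Collect HumanData slices and dump them to a cache npz file.

    The dumped file is the format read by HumanDataCacheReader.

    Args:
        slice_size (int):
            Number of samples stored in each slice.
        data_len (int):
            Total number of samples.
        keypoints_info (dict):
            Keypoint entries merged into every slice when it is read.
        non_sliced_data (dict):
            Entries shared by all samples, fetched on demand by key.
        key_strict (bool, optional):
            Stored in the dumped file as-is. Defaults to True.
    """
    def __init__(self,
                 slice_size: int,
                 data_len: int,
                 keypoints_info: dict,
                 non_sliced_data: dict,
                 key_strict: bool = True):
        self.slice_size = slice_size
        self.data_len = data_len
        self.keypoints_info = keypoints_info
        self.non_sliced_data = non_sliced_data
        self.sliced_data = {}
        self.key_strict = key_strict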

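    def update_sliced_dict(self, sliced_dict: dict):
        """Add or overwrite slices, keyed by the slice index as a string
        (e.g. '0', '1'), which is what HumanDataCacheReader looks up."""
        self.sliced_data.update(sliced_dict)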

    def dump(self, npz_path: str, overwrite: bool = True):
        """Dump keys and items to an npz file.

        Args:
            npz_path (str):
                Path of the npz file to write.
            overwrite (bool, optional):
                Whether to overwrite if there is already a file.
                Defaults to True.

        Raises:
            ValueError:
                npz_path does not end with '.npz'.
            FileExistsError:
                When overwrite is False and file exists.
        """
        if not check_path_suffix(npz_path, ['.npz']):
            raise ValueError(f'Not an npz file: {npz_path}')
        if not overwrite:
            if check_path_existence(npz_path, 'file') == Existence.FileExist:
                raise FileExistsError(f'{npz_path} already exists.')
        dict_to_dump = {
            'slice_size': self.slice_size,
            'data_len': self.data_len,
            'keypoints_info': self.keypoints_info,
            'non_sliced_data': self.non_sliced_data,
            'key_strict': self.key_strict,
        }
        dict_to_dump.update(self.sliced_data)
        np.savez_compressed(npz_path, **dict_to_dump)
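

# Illustrative sketch, not part of the original module: _split_into_slices
# is a hypothetical helper. HumanDataCacheReader.get_item() loads one entry
# per slice key and calls .item() on it, so each value passed to
# HumanDataCacheWriter.update_sliced_dict() is presumably a dict keyed by
# the slice index as a string. This assumes every value in per_sample_data
# is sliceable along its first axis (a list or numpy array) and holds
# data_len samples; real HumanData fields may use another sample axis.
def _split_into_slices(per_sample_data: dict, data_len: int,
                       slice_size: int) -> dict:
    """Split per-sample data into {'0': {...}, '1': {...}, ...}."""
    sliced = {}
    # ceil(data_len / slice_size) slices; the last one may be shorter.
    num_slices = (data_len + slice_size - 1) // slice_size
    for slice_index in range(num_slices):
        start = slice_index * slice_size
        end = min(start + slice_size, data_len)
        sliced[str(slice_index)] = {
            key: value[start:end]
            for key, value in per_sample_data.items()
        }
    return sliced


# Possible usage (assumption): writer.update_sliced_dict(
#     _split_into_slices(data, writer.data_len, writer.slice_size))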