File size: 5,461 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import numpy as np
from mmengine.dataset import BaseDataset, force_full_init

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class KFoldDataset:
    """A wrapper of dataset for K-Fold cross-validation.

    K-Fold cross-validation divides all the samples in groups of samples,
    called folds, of almost equal sizes. And we use k-1 of folds to do training
    and use the fold left to do validation.

    After full initialization, ``self.dataset`` is replaced by the subset of
    the wrapped dataset that belongs to the selected split, and indices
    accepted by this wrapper refer to positions inside that subset.

    Args:
        dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to be
            divided. A dict is built through the ``DATASETS`` registry; its
            optional ``lazy_init`` key decides whether this wrapper is
            initialized lazily as well.
        fold (int): The fold used to do validation. Defaults to 0.
        num_splits (int): The number of all folds. Defaults to 5.
        test_mode (bool): Use the training dataset or validation dataset.
            Defaults to False.
        seed (int, optional): The seed to shuffle the dataset before splitting.
            If None, not shuffle the dataset. Defaults to None.

    Raises:
        TypeError: If ``dataset`` is neither a dict nor a ``BaseDataset``.
    """

    def __init__(self,
                 dataset,
                 fold=0,
                 num_splits=5,
                 test_mode=False,
                 seed=None):
        # Initialize eagerly unless a config dict explicitly asks for lazy
        # initialization. Without this default, passing an already-built
        # ``BaseDataset`` instance left ``lazy_init`` unbound and the check
        # at the end of ``__init__`` raised ``UnboundLocalError``.
        lazy_init = False
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
            # Init the dataset wrapper lazily according to the dataset setting.
            lazy_init = dataset.get('lazy_init', False)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(f'Unsupported dataset type {type(dataset)}.')

        # Fall back to an empty dict when the wrapped dataset exposes no
        # ``metainfo`` attribute.
        self._metainfo = getattr(self.dataset, 'metainfo', {})
        self.fold = fold
        self.num_splits = num_splits
        self.test_mode = test_mode
        self.seed = seed

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of ``self.dataset``.

        Returns:
            dict: Meta information of the dataset.
        """
        # Prevent `self._metainfo` from being modified by outside.
        return copy.deepcopy(self._metainfo)

    def full_init(self):
        """Fully initialize the dataset and select the samples of the split.

        The wrapped dataset is fully initialized first. Its indices are then
        optionally shuffled with ``self.seed`` and the slice belonging to
        ``self.fold`` is taken as the validation fold. With ``test_mode`` the
        wrapper keeps only that slice, otherwise it keeps everything else.
        Calling this method again after initialization is a no-op.
        """
        if self._fully_initialized:
            return

        self.dataset.full_init()
        ori_len = len(self.dataset)
        indices = list(range(ori_len))
        if self.seed is not None:
            # A seeded generator keeps the shuffle reproducible, so wrappers
            # created with the same seed see the same permutation.
            rng = np.random.default_rng(self.seed)
            rng.shuffle(indices)

        # Integer arithmetic on the total length keeps fold boundaries
        # contiguous: the end of fold ``i`` equals the start of fold ``i+1``.
        test_start = ori_len * self.fold // self.num_splits
        test_end = ori_len * (self.fold + 1) // self.num_splits
        if self.test_mode:
            indices = indices[test_start:test_end]
        else:
            indices = indices[:test_start] + indices[test_end:]

        # Remember which original samples ended up in this split, then swap
        # the wrapped dataset for the corresponding subset.
        self._ori_indices = indices
        self.dataset = self.dataset.get_subset(indices)

        self._fully_initialized = True

    @force_full_init
    def _get_ori_dataset_idx(self, idx: int) -> int:
        """Convert an index of this wrapper to an index of the whole dataset.

        Args:
            idx (int): Index of a sample inside this ``KFoldDataset`` split.

        Returns:
            int: The original index in the whole dataset.
        """
        return self._ori_indices[idx]

    @force_full_init
    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Global index of ``KFoldDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        return self.dataset.get_data_info(idx)

    @force_full_init
    def __len__(self):
        """Return the length of ``self.dataset``."""
        return len(self.dataset)

    @force_full_init
    def __getitem__(self, idx):
        """Return ``self.dataset[idx]``."""
        return self.dataset[idx]

    @force_full_init
    def get_cat_ids(self, idx):
        """Forward to ``get_cat_ids`` of ``self.dataset`` for index ``idx``."""
        return self.dataset.get_cat_ids(idx)

    @force_full_init
    def get_gt_labels(self):
        """Forward to ``get_gt_labels`` of ``self.dataset``."""
        return self.dataset.get_gt_labels()

    @property
    def CLASSES(self):
        """Return all categories names, or None if not set in the metainfo."""
        return self._metainfo.get('classes', None)

    @property
    def class_to_idx(self):
        """Map mapping class name to class index.

        Returns:
            dict: mapping from class name to class index.
        """

        return {cat: i for i, cat in enumerate(self.CLASSES)}

    def __repr__(self):
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = []
        type_ = 'test' if self.test_mode else 'training'
        body.append(f'Type: \t{type_}')
        body.append(f'Seed: \t{self.seed}')

        def ordinal(n):
            # Copy from https://codegolf.stackexchange.com/a/74047
            # Picks 'st'/'nd'/'rd' for last digits 1-3 (except when the tens
            # digit is 1, e.g. 11-13) and 'th' otherwise.
            suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4]
            return f'{n}{suffix}'

        # ``self.fold`` is zero-based; show it one-based for readability.
        body.append(
            f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold')
        if self._fully_initialized:
            body.append(f'Number of samples: \t{len(self)}')
        else:
            body.append("Haven't been initialized")

        if self.CLASSES is not None:
            body.append(f'Number of categories: \t{len(self.CLASSES)}')
        else:
            body.append('The `CLASSES` meta info is not set.')

        body.append(
            f'Original dataset type:\t{self.dataset.__class__.__name__}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)