# Spaces: Runtime error
# (Status banner of the hosting page captured with the file; not part of the module.)
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import numpy as np
from mmengine.dataset import BaseDataset, force_full_init
from mmpretrain.registry import DATASETS
class KFoldDataset:
    """A wrapper of dataset for K-Fold cross-validation.

    K-Fold cross-validation divides all the samples into groups of samples,
    called folds, of almost equal sizes. ``k - 1`` folds are used for
    training and the remaining fold is used for validation.

    Args:
        dataset (:obj:`mmengine.dataset.BaseDataset` | dict): The dataset to be
            divided.
        fold (int): The fold used to do validation. Defaults to 0.
        num_splits (int): The number of all folds. Defaults to 5.
        test_mode (bool): Use the training dataset or validation dataset.
            Defaults to False.
        seed (int, optional): The seed to shuffle the dataset before splitting.
            If None, not shuffle the dataset. Defaults to None.

    Raises:
        TypeError: If ``dataset`` is neither a dict nor a ``BaseDataset``.
    """

    def __init__(self,
                 dataset,
                 fold=0,
                 num_splits=5,
                 test_mode=False,
                 seed=None):
        # Only a config dict can ask for lazy initialization. Defaulting here
        # keeps the flag defined when an already-built dataset is passed.
        lazy_init = False
        if isinstance(dataset, dict):
            self.dataset = DATASETS.build(dataset)
            # Init the dataset wrapper lazily according to the dataset setting.
            lazy_init = dataset.get('lazy_init', False)
        elif isinstance(dataset, BaseDataset):
            self.dataset = dataset
        else:
            raise TypeError(f'Unsupported dataset type {type(dataset)}.')

        self._metainfo = getattr(self.dataset, 'metainfo', {})
        self.fold = fold
        self.num_splits = num_splits
        self.test_mode = test_mode
        self.seed = seed

        self._fully_initialized = False
        if not lazy_init:
            self.full_init()

    @property
    def metainfo(self) -> dict:
        """Get the meta information of ``self.dataset``.

        Returns:
            dict: Meta information of the dataset.
        """
        # Prevent `self._metainfo` from being modified by outside.
        return copy.deepcopy(self._metainfo)

    def full_init(self) -> None:
        """Fully initialize the dataset and keep only this fold's samples.

        The wrapped dataset is initialized, its indices are optionally
        shuffled with ``seed``, and ``self.dataset`` is replaced by the subset
        for the selected fold (validation part if ``test_mode`` else the
        rest). Calling it again is a no-op.
        """
        if self._fully_initialized:
            return

        self.dataset.full_init()
        ori_len = len(self.dataset)
        indices = list(range(ori_len))
        if self.seed is not None:
            rng = np.random.default_rng(self.seed)
            rng.shuffle(indices)

        # Integer arithmetic spreads the remainder so fold sizes differ by
        # at most one sample.
        test_start = ori_len * self.fold // self.num_splits
        test_end = ori_len * (self.fold + 1) // self.num_splits
        if self.test_mode:
            indices = indices[test_start:test_end]
        else:
            indices = indices[:test_start] + indices[test_end:]

        self._ori_indices = indices
        self.dataset = self.dataset.get_subset(indices)
        self._fully_initialized = True

    def _get_ori_dataset_idx(self, idx: int) -> int:
        """Convert an index of this wrapper to the original dataset index.

        Args:
            idx (int): Index into this ``KFoldDataset``.

        Returns:
            int: The original index in the whole dataset.
        """
        self.full_init()  # `_ori_indices` only exists after the split
        return self._ori_indices[idx]

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        Args:
            idx (int): Index into this ``KFoldDataset``.

        Returns:
            dict: The idx-th annotation of the datasets.
        """
        self.full_init()  # no-op once the split has been applied
        return self.dataset.get_data_info(idx)

    def __len__(self) -> int:
        """Return the number of samples in the selected fold subset."""
        self.full_init()  # otherwise the unsplit length would be reported
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample of the selected fold subset."""
        self.full_init()  # otherwise the unsplit dataset would be indexed
        return self.dataset[idx]

    def get_cat_ids(self, idx):
        """Forward ``get_cat_ids(idx)`` to the fold subset."""
        self.full_init()
        return self.dataset.get_cat_ids(idx)

    def get_gt_labels(self):
        """Forward ``get_gt_labels()`` to the fold subset."""
        self.full_init()
        return self.dataset.get_gt_labels()

    @property
    def CLASSES(self):
        """Return all categories names, or None if not set in the metainfo."""
        return self._metainfo.get('classes', None)

    @property
    def class_to_idx(self) -> dict:
        """Map mapping class name to class index.

        Returns:
            dict: mapping from class name to class index.
        """
        return {cat: i for i, cat in enumerate(self.CLASSES)}

    def __repr__(self) -> str:
        """Print the basic information of the dataset.

        Returns:
            str: Formatted string.
        """
        head = 'Dataset ' + self.__class__.__name__
        body = []
        type_ = 'test' if self.test_mode else 'training'
        body.append(f'Type: \t{type_}')
        body.append(f'Seed: \t{self.seed}')

        def ordinal(n):
            # Copy from https://codegolf.stackexchange.com/a/74047
            suffix = 'tsnrhtdd'[(n // 10 % 10 != 1) * (n % 10 < 4) * n % 10::4]
            return f'{n}{suffix}'

        body.append(
            f'Fold: \t{ordinal(self.fold+1)} of {self.num_splits}-fold')
        # Do not trigger initialization just to print the object.
        if self._fully_initialized:
            body.append(f'Number of samples: \t{self.__len__()}')
        else:
            body.append("Haven't been initialized")

        if self.CLASSES is not None:
            body.append(f'Number of categories: \t{len(self.CLASSES)}')
        else:
            body.append('The `CLASSES` meta info is not set.')

        body.append(
            f'Original dataset type:\t{self.dataset.__class__.__name__}')

        lines = [head] + [' ' * 4 + line for line in body]
        return '\n'.join(lines)