# Copyright (c) OpenMMLab. All rights reserved. import os.path as osp from typing import List, Optional from mmengine.dataset import BaseDataset from mmengine.fileio import FileClient, load from mmengine.utils import is_abs from ..registry import DATASETS @DATASETS.register_module() class BaseDetDataset(BaseDataset): """Base dataset for detection. Args: proposal_file (str, optional): Proposals file path. Defaults to None. file_client_args (dict): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. Defaults to ``dict(backend='disk')``. """ def __init__(self, *args, seg_map_suffix: str = '.png', proposal_file: Optional[str] = None, file_client_args: dict = dict(backend='disk'), **kwargs) -> None: self.seg_map_suffix = seg_map_suffix self.proposal_file = proposal_file self.file_client_args = file_client_args self.file_client = FileClient(**file_client_args) super().__init__(*args, **kwargs) def full_init(self) -> None: """Load annotation file and set ``BaseDataset._fully_initialized`` to True. If ``lazy_init=False``, ``full_init`` will be called during the instantiation and ``self._fully_initialized`` will be set to True. If ``obj._fully_initialized=False``, the class method decorated by ``force_full_init`` will call ``full_init`` automatically. Several steps to initialize annotation: - load_data_list: Load annotations from annotation file. - load_proposals: Load proposals from proposal file, if `self.proposal_file` is not None. - filter data information: Filter annotations according to filter_cfg. - slice_data: Slice dataset according to ``self._indices`` - serialize_data: Serialize ``self.data_list`` if ``self.serialize_data`` is True. """ if self._fully_initialized: return # load data information self.data_list = self.load_data_list() # get proposals from file if self.proposal_file is not None: self.load_proposals() # filter illegal data, such as data that has no annotations. self.data_list = self.filter_data() # Get subset data according to indices. if self._indices is not None: self.data_list = self._get_unserialized_subset(self._indices) # serialize data_list if self.serialize_data: self.data_bytes, self.data_address = self._serialize_data() self._fully_initialized = True def load_proposals(self) -> None: """Load proposals from proposals file. The `proposals_list` should be a dict[img_path: proposals] with the same length as `data_list`. And the `proposals` should be a `dict` or :obj:`InstanceData` usually contains following keys. - bboxes (np.ndarry): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). - scores (np.ndarry): Classification scores, has a shape (num_instance, ). """ # TODO: Add Unit Test after fully support Dump-Proposal Metric if not is_abs(self.proposal_file): self.proposal_file = osp.join(self.data_root, self.proposal_file) proposals_list = load( self.proposal_file, file_client_args=self.file_client_args) assert len(self.data_list) == len(proposals_list) for data_info in self.data_list: img_path = data_info['img_path'] # `file_name` is the key to obtain the proposals from the # `proposals_list`. file_name = osp.join( osp.split(osp.split(img_path)[0])[-1], osp.split(img_path)[-1]) proposals = proposals_list[file_name] data_info['proposals'] = proposals def get_cat_ids(self, idx: int) -> List[int]: """Get COCO category ids by index. Args: idx (int): Index of data. Returns: List[int]: All categories in the image of specified index. """ instances = self.get_data_info(idx)['instances'] return [instance['bbox_label'] for instance in instances]