File size: 4,553 Bytes
f549064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) OpenMMLab. All rights reserved.
import os
import os.path as osp
from typing import Optional, Sequence

from mmengine.dist import is_main_process
from mmengine.evaluator import BaseMetric
from mmengine.fileio import dump
from mmengine.logging import MMLogger
from mmengine.structures import InstanceData

from mmdet.registry import METRICS


@METRICS.register_module()
class DumpProposals(BaseMetric):
    """Dump proposals pseudo metric.

    Collects the predicted proposals of every image during evaluation and
    dumps them to a single pickle file instead of computing a real metric.

    Args:
        output_dir (str): The root directory for ``proposals_file``.
            An empty string means the current working directory.
            Defaults to ''.
        proposals_file (str): Proposals file path. Must end with ``.pkl``
            or ``.pickle``. Defaults to 'proposals.pkl'.
        num_max_proposals (int, optional): Maximum number of proposals to dump.
            If not specified, all proposals will be dumped.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmengine.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Defaults to None.
    """

    default_prefix: Optional[str] = 'dump_proposals'

    def __init__(self,
                 output_dir: str = '',
                 proposals_file: str = 'proposals.pkl',
                 num_max_proposals: Optional[int] = None,
                 file_client_args: dict = dict(backend='disk'),
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.num_max_proposals = num_max_proposals
        # TODO: update after mmengine finish refactor fileio.
        self.file_client_args = file_client_args
        self.output_dir = output_dir
        assert proposals_file.endswith(('.pkl', '.pickle')), \
            'The output file must be a pkl file.'

        self.proposals_file = os.path.join(self.output_dir, proposals_file)
        # Only the main process creates the directory to avoid a race in
        # distributed runs. ``os.makedirs('')`` raises FileNotFoundError,
        # so skip creation when output_dir is empty (current directory).
        if self.output_dir and is_main_process():
            os.makedirs(self.output_dir, exist_ok=True)

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of data samples that
                contain annotations and predictions.
        """
        for data_sample in data_samples:
            pred = data_sample['pred_instances']
            # Sort proposals by score so that truncation below keeps the
            # highest-scoring ones.
            ranked_scores, rank_inds = pred['scores'].sort(descending=True)
            ranked_bboxes = pred['bboxes'][rank_inds, :]

            ranked_bboxes = ranked_bboxes.cpu().numpy()
            ranked_scores = ranked_scores.cpu().numpy()

            pred_instance = InstanceData()
            pred_instance.bboxes = ranked_bboxes
            pred_instance.scores = ranked_scores
            if self.num_max_proposals is not None:
                pred_instance = pred_instance[:self.num_max_proposals]

            img_path = data_sample['img_path']
            # `file_name` is the key to obtain the proposals from the
            # `proposals_list`: last directory component + base file name,
            # e.g. 'val2017/000000000139.jpg'.
            file_name = osp.join(
                osp.split(osp.split(img_path)[0])[-1],
                osp.split(img_path)[-1])
            result = {file_name: pred_instance}
            self.results.append(result)

    def compute_metrics(self, results: list) -> dict:
        """Dump the processed results.

        Args:
            results (list): The processed results of each batch, each a
                single-entry dict mapping a file name to its proposals.

        Returns:
            dict: An empty dict, since this metric only dumps files.
        """
        logger: MMLogger = MMLogger.get_current_instance()
        dump_results = {}
        for result in results:
            dump_results.update(result)
        dump(
            dump_results,
            file=self.proposals_file,
            file_client_args=self.file_client_args)
        logger.info(f'Results are saved at {self.proposals_file}')
        return {}