# Copyright (c) OpenMMLab. All rights reserved.
import json
import os
import tempfile
from typing import List, Optional

from mmengine.evaluator import BaseMetric
from mmengine.utils import track_iter_progress
from pycocotools.coco import COCO

from mmdet.registry import METRICS

try:
    from pycocoevalcap.eval import COCOEvalCap
except ImportError:
    COCOEvalCap = None


@METRICS.register_module()
class COCOCaptionMetric(BaseMetric):
    """Coco Caption evaluation wrapper.

    Save the generated captions and transform into coco format.
    Calling COCO API for caption metrics.

    Args:
        ann_file (str): Path to the COCO-format caption ground-truth json
            file, which is loaded for evaluation.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            will be used instead. Defaults to None.
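
    Examples:
        A minimal evaluator config sketch (the ``ann_file`` path below is a
        placeholder; adjust it to your dataset layout):

        >>> val_evaluator = dict(
        ...     type='COCOCaptionMetric',
        ...     ann_file='data/coco/annotations/coco_karpathy_val_gt.json')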
    """

    def __init__(self,
                 ann_file: str,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None):
        if COCOEvalCap is None:
            raise RuntimeError(
                'COCOEvalCap is not installed. Please install it with: '
                'pip install pycocoevalcap')

        super().__init__(collect_device=collect_device, prefix=prefix)
        self.ann_file = ann_file

    def process(self, data_batch, data_samples):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
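
        Each ``data_sample`` is expected to provide a predicted caption and
        an image id, e.g. (values are illustrative only)
        ``{'img_id': 397133, 'pred_caption': 'a man riding a horse'}``.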
        """

        for data_sample in data_samples:
            result = dict()

            result['caption'] = data_sample['pred_caption']
            result['image_id'] = int(data_sample['img_id'])

            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from this method; use the
        # `results` argument instead.

        with tempfile.TemporaryDirectory() as temp_dir:

            eval_result_file = save_result(
                result=results,
                result_dir=temp_dir,
                filename='caption_pred',
                remove_duplicate='image_id',
            )

            coco_val = coco_caption_eval(eval_result_file, self.ann_file)

        return coco_val


def save_result(result, result_dir, filename, remove_duplicate=''):
    """Saving predictions as json file for evaluation."""
    # drop duplicated predictions keyed by `remove_duplicate`, if requested
    if remove_duplicate:
        result_new = []
        id_list = []
        for res in track_iter_progress(result):
            if res[remove_duplicate] not in id_list:
                id_list.append(res[remove_duplicate])
                result_new.append(res)
        result = result_new

    final_result_file_url = os.path.join(result_dir, f'{filename}.json')
    with open(final_result_file_url, 'w') as f:
        json.dump(result, f)
    print(f'result file saved to {final_result_file_url}')

    return final_result_file_url
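

# Usage sketch for `save_result` (the predictions and directory below are
# illustrative placeholders):
#
#   preds = [{'image_id': 1, 'caption': 'a cat sitting on a mat'},
#            {'image_id': 1, 'caption': 'another caption for the same image'}]
#   pred_file = save_result(preds, '/tmp', 'caption_pred',
#                           remove_duplicate='image_id')
#   # only the first prediction for image_id 1 is written to the json file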


def coco_caption_eval(results_file, ann_file):
    """Evaluation between gt json and prediction json files."""
    # create coco object and coco_result object
    coco = COCO(ann_file)
    coco_result = coco.loadRes(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # evaluate only on the image ids present in the prediction file
    coco_eval.params['image_id'] = coco_result.getImgIds()

    # this may take some time on the first run
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f'{metric}: {score:.3f}')

    return coco_eval.eval
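

# Usage sketch for `coco_caption_eval` (the paths are placeholders, not
# files shipped with mmdet):
#
#   scores = coco_caption_eval(
#       results_file='work_dirs/caption_pred.json',
#       ann_file='data/coco/annotations/coco_karpathy_val_gt.json')
#
# `scores` maps metric names to values; with pycocoevalcap it typically
# contains entries such as Bleu_1..Bleu_4, METEOR, ROUGE_L, CIDEr and SPICE.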