File size: 5,515 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from collections import OrderedDict
from os import PathLike
from typing import List, Sequence, Union

from mmengine import get_file_backend

from mmpretrain.registry import DATASETS, TRANSFORMS
from .base_dataset import BaseDataset


def expanduser(data_prefix):
    if isinstance(data_prefix, (str, PathLike)):
        return osp.expanduser(data_prefix)
    else:
        return data_prefix


@DATASETS.register_module()
class COCORetrieval(BaseDataset):
    """COCO Retrieval dataset.

    COCO (Common Objects in Context): The COCO dataset contains more than
    330K images,each of which has approximately 5 descriptive annotations.
    This dataset was releasedin collaboration between Microsoft and Carnegie
    Mellon University

    COCO_2014 dataset directory: ::

        COCO_2014
        β”œβ”€β”€ val2014
        β”œβ”€β”€ train2014
        β”œβ”€β”€ annotations
                 β”œβ”€β”€ instances_train2014.json
                 β”œβ”€β”€ instances_val2014.json
                 β”œβ”€β”€ person_keypoints_train2014.json
                 β”œβ”€β”€ person_keypoints_val2014.json
                 β”œβ”€β”€ captions_train2014.json
                 β”œβ”€β”€ captions_val2014.json

    Args:
        ann_file (str): Annotation file path.
        test_mode (bool): Whether dataset is used for evaluation. This will
            decide the annotation format in data list annotations.
            Defaults to False.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. Defaults to ''.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.

    Examples:
        >>> from mmpretrain.datasets import COCORetrieval
        >>> train_dataset=COCORetrieval(data_root='coco2014/')
        >>> train_dataset
        Dataset COCORetrieval
            Number of samples: 	414113
            Annotation file:  /coco2014/annotations/captions_train2014.json
            Prefix of images:  /coco2014/
        >>> from mmpretrain.datasets import COCORetrieval
        >>> val_dataset = COCORetrieval(data_root='coco2014/')
        >>> val_dataset
         Dataset COCORetrieval
             Number of samples: 	202654
             Annotation file: 	/coco2014/annotations/captions_val2014.json
             Prefix of images: 	/coco2014/
    """

    def __init__(self,
                 ann_file: str,
                 test_mode: bool = False,
                 data_prefix: Union[str, dict] = '',
                 data_root: str = '',
                 pipeline: Sequence = (),
                 **kwargs):

        if isinstance(data_prefix, str):
            data_prefix = dict(img_path=expanduser(data_prefix))

        ann_file = expanduser(ann_file)
        transforms = []
        for transform in pipeline:
            if isinstance(transform, dict):
                transforms.append(TRANSFORMS.build(transform))
            else:
                transforms.append(transform)

        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            pipeline=transforms,
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        # get file backend
        img_prefix = self.data_prefix['img_path']
        file_backend = get_file_backend(img_prefix)

        anno_info = json.load(open(self.ann_file, 'r'))
        # mapping img_id to img filename
        img_dict = OrderedDict()
        for idx, img in enumerate(anno_info['images']):
            if img['id'] not in img_dict:
                img_rel_path = img['coco_url'].rsplit('/', 2)[-2:]
                img_path = file_backend.join_path(img_prefix, *img_rel_path)

                # create new idx for image
                img_dict[img['id']] = dict(
                    ori_id=img['id'],
                    image_id=idx,  # will be used for evaluation
                    img_path=img_path,
                    text=[],
                    gt_text_id=[],
                    gt_image_id=[],
                )

        train_list = []
        for idx, anno in enumerate(anno_info['annotations']):
            anno['text'] = anno.pop('caption')
            anno['ori_id'] = anno.pop('id')
            anno['text_id'] = idx  # will be used for evaluation
            # 1. prepare train data list item
            train_data = anno.copy()
            train_image = img_dict[train_data['image_id']]
            train_data['img_path'] = train_image['img_path']
            train_data['image_ori_id'] = train_image['ori_id']
            train_data['image_id'] = train_image['image_id']
            train_data['is_matched'] = True
            train_list.append(train_data)
            # 2. prepare eval data list item based on img dict
            img_dict[anno['image_id']]['gt_text_id'].append(anno['text_id'])
            img_dict[anno['image_id']]['text'].append(anno['text'])
            img_dict[anno['image_id']]['gt_image_id'].append(
                train_image['image_id'])

        self.img_size = len(img_dict)
        self.text_size = len(anno_info['annotations'])

        # return needed format data list
        if self.test_mode:
            return list(img_dict.values())
        return train_list