File size: 1,989 Bytes
3b96cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class InfographicVQA(BaseDataset):
    """Infographic VQA dataset.

    Each sample pairs a question about an infographic image with, on
    splits that provide them, its ground-truth answers.

    Args:
        data_root (str): The root directory for ``data_prefix``, ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            # BaseDataset takes a dict of prefixes; the image directory is
            # stored under ``img_path`` and read back in load_data_list.
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load the annotation file and build one data-info dict per entry.

        The annotation file is expected to contain a top-level ``data``
        list whose entries look like::

            {
            "questionId": 98313,
            "question": "Which social platform has heavy female audience?",
            "image_local_name": "37313.jpeg",
            "image_url": "https://xxx.png",
            "ocr_output_file": "37313.json",
            "answers": ["pinterest"],
            "data_split": "val"
            }

        Returns:
            List[dict]: One dict per annotation entry with ``question`` and
            ``img_path`` (``image_local_name`` joined onto the image
            prefix), plus ``gt_answer`` copied from ``answers`` when the
            entry provides it.
        """
        annotations = mmengine.load(self.ann_file)['data']

        data_list = []
        for ann in annotations:
            data_info = dict(
                question=ann['question'],
                img_path=mmengine.join_path(self.data_prefix['img_path'],
                                            ann['image_local_name']),
            )
            # Test splits do not include ground truth, so ``answers`` is
            # optional and ``gt_answer`` is only set when it is present.
            if 'answers' in ann:
                data_info['gt_answer'] = ann['answers']
            data_list.append(data_info)

        return data_list