# Source path: TTP/mmpretrain/datasets/infographic_vqa.py
# Uploaded by KyanChen ("Upload 1861 files"), commit 3b96cb1.
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class InfographicVQA(BaseDataset):
    """Infographic VQA dataset.

    Each sample produced by :meth:`load_data_list` is a dict with the keys
    ``question`` and ``img_path``, plus ``gt_answer`` when the annotation
    entry provides answers.

    Args:
        data_root (str): The root directory for ``data_prefix``, ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            # Store the image directory under the ``img_path`` key; it is
            # read back as ``self.data_prefix['img_path']`` in
            # ``load_data_list``.
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load the annotation file and build the list of samples.

        The annotation file is expected to contain a top-level ``data`` list
        whose entries look like::

            {
                "questionId": 98313,
                "question": "Which social platform has heavy female audience?",
                "image_local_name": "37313.jpeg",
                "image_url": "https://xxx.png",
                "ocr_output_file": "37313.json",
                "answers": ["pinterest"],
                "data_split": "val"
            }

        Returns:
            List[dict]: One dict per annotation entry with ``question``,
            ``img_path`` (the image directory joined with
            ``image_local_name``) and, if present in the entry,
            ``gt_answer`` (the entry's ``answers`` value).
        """
        annotations = mmengine.load(self.ann_file)['data']
        # Loop-invariant: the image directory is the same for every entry.
        img_prefix = self.data_prefix['img_path']

        data_list = []
        for ann in annotations:
            data_info = dict(
                question=ann['question'],
                img_path=mmengine.join_path(img_prefix,
                                            ann['image_local_name']),
            )
            # Test splits do not include ground-truth answers.
            if 'answers' in ann:
                data_info['gt_answer'] = ann['answers']
            data_list.append(data_info)
        return data_list