TTP / mmpretrain /datasets /scienceqa.py
KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, List, Sequence
import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class ScienceQA(BaseDataset):
    """Dataset wrapper for the multimodal ScienceQA benchmark.

    Every sample is a multiple-choice science question with its metadata
    and, optionally, an accompanying image.

    Args:
        data_root (str): Root directory that ``data_prefix``, ``ann_file``
            and ``split_file`` are resolved against.
        split (str): Name of the split to load. One of ``train``, ``val``,
            ``test``, ``trainval``, ``minival`` or ``minitest``.
        split_file (str): Path (relative to ``data_root``) of the file that
            maps each split name to the ids of its samples.
        ann_file (str): Path of the annotation file, indexed by sample id.
        image_only (bool): If True, samples whose annotation has no image
            are skipped. Defaults to False.
        data_prefix (dict): Prefix for data fields. Defaults to
            ``dict(img_path='')``.
        pipeline (Sequence): Processing pipeline. Defaults to an empty
            tuple.
        **kwargs: Extra keyword arguments forwarded to
            :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 split: str,
                 split_file: str,
                 ann_file: str,
                 image_only: bool = False,
                 data_prefix: dict = dict(img_path=''),
                 pipeline: Sequence[Callable] = (),
                 **kwargs):
        valid_splits = ('train', 'val', 'test', 'trainval', 'minival',
                        'minitest')
        assert split in valid_splits, f'Invalid split {split}'

        # These attributes are read by ``load_data_list``; they are set
        # before calling the base initializer, which presumably may
        # trigger loading right away.
        self.split = split
        self.split_file = os.path.join(data_root, split_file)
        self.image_only = image_only

        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            data_prefix=data_prefix,
            pipeline=pipeline,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Build the list of sample dicts for the selected split.

        Returns:
            List[dict]: One dict per sample id listed under ``self.split``
            in the split file. Each dict carries the annotation fields,
            plus ``image_id``, ``img_path`` (``None`` when the sample has
            no image) and ``has_image``.
        """
        img_prefix = self.data_prefix['img_path']
        annotations = mmengine.load(self.ann_file)
        split_ids = mmengine.load(self.split_file)[self.split]  # noqa
        file_backend = get_file_backend(img_prefix)

        # (output key, annotation key) pairs, copied verbatim in this order.
        copied_fields = (
            ('question', 'question'),
            ('choices', 'choices'),
            ('gt_answer', 'answer'),
            ('hint', 'hint'),
            ('image_name', 'image'),
            ('task', 'task'),
            ('grade', 'grade'),
            ('subject', 'subject'),
            ('topic', 'topic'),
            ('category', 'category'),
            ('skill', 'skill'),
            ('lecture', 'lecture'),
            ('solution', 'solution'),
            ('split', 'split'),
        )

        samples = []
        for sample_id in split_ids:
            record = annotations[sample_id]
            if self.image_only and record['image'] is None:
                continue

            sample = {'image_id': sample_id}
            for out_key, ann_key in copied_fields:
                sample[out_key] = record[ann_key]

            image_name = sample['image_name']
            has_image = image_name is not None
            if has_image:
                # Images are stored as <img_prefix>/<sample id>/<file name>.
                sample['img_path'] = file_backend.join_path(
                    img_prefix, sample_id, image_name)
            else:
                sample['img_path'] = None
            sample['has_image'] = has_image

            samples.append(sample)

        return samples