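# NOTE(review): the text "Spaces: / Runtime error / Runtime error" stood here
# as bare lines. It appears to be page-status residue picked up when this file
# was copied rather than part of the module - confirm and drop if so.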
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, List, Sequence

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend

from mmpretrain.registry import DATASETS
class ScienceQA(BaseDataset):
    """ScienceQA dataset.

    This dataset is used to load the multimodal data of ScienceQA dataset.

    The annotation file is loaded as a mapping indexed by sample id, and the
    split file as a mapping from split name to the sample ids of that split.
    Every annotation entry is expected to provide the keys ``question``,
    ``choices``, ``answer``, ``hint``, ``image``, ``task``, ``grade``,
    ``subject``, ``topic``, ``category``, ``skill``, ``lecture``,
    ``solution`` and ``split``.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        split (str): The split of dataset. Options: ``train``, ``val``,
            ``test``, ``trainval``, ``minival``, and ``minitest``.
        split_file (str): The split file of dataset, which contains the
            ids of data samples in the split. It is joined onto
            ``data_root``.
        ann_file (str): Annotation file path.
        image_only (bool): Whether only to load data with image. Defaults to
            False.
        data_prefix (dict): Prefix for data field. Defaults to
            ``dict(img_path='')``.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.

    Raises:
        AssertionError: If ``split`` is not one of the supported options.
    """

    # NOTE(review): ``DATASETS`` is imported by this file but never used in
    # the code visible here; this class was presumably meant to be decorated
    # with ``@DATASETS.register_module()``. Confirm before adding it, since
    # registration is an import-time side effect.

    def __init__(self,
                 data_root: str,
                 split: str,
                 split_file: str,
                 ann_file: str,
                 image_only: bool = False,
                 data_prefix: dict = dict(img_path=''),
                 pipeline: Sequence[Callable] = (),
                 **kwargs):
        valid_splits = ('train', 'val', 'test', 'trainval', 'minival',
                        'minitest')
        # Raised explicitly instead of via ``assert`` so that the check still
        # runs under ``python -O``; the exception type and message are the
        # same as the assert produced.
        if split not in valid_splits:
            raise AssertionError(f'Invalid split {split}')

        # Set before the base initializer runs - presumably because it may
        # already call ``load_data_list``, which reads these attributes
        # (TODO confirm against BaseDataset).
        self.split = split
        self.split_file = os.path.join(data_root, split_file)
        self.image_only = image_only

        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            data_prefix=data_prefix,
            pipeline=pipeline,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Returns:
            List[dict]: One record per sample id listed under ``self.split``
            in the split file, in that order. Samples whose ``image`` entry
            is ``None`` are skipped when ``self.image_only`` is set. Each
            record carries the annotation fields plus ``image_id``,
            ``img_path`` (``None`` when the sample has no image) and
            ``has_image``.
        """
        img_prefix = self.data_prefix['img_path']
        annotations = mmengine.load(self.ann_file)
        split_ids = mmengine.load(self.split_file)[self.split]
        file_backend = get_file_backend(img_prefix)

        data_list = []
        for data_id in split_ids:
            ann = annotations[data_id]
            image_name = ann['image']
            has_image = image_name is not None

            if self.image_only and not has_image:
                continue

            # Images are stored as <img_prefix>/<sample id>/<image name>.
            img_path = None
            if has_image:
                img_path = file_backend.join_path(img_prefix, data_id,
                                                  image_name)

            data_list.append({
                'image_id': data_id,
                'question': ann['question'],
                'choices': ann['choices'],
                'gt_answer': ann['answer'],
                'hint': ann['hint'],
                'image_name': image_name,
                'task': ann['task'],
                'grade': ann['grade'],
                'subject': ann['subject'],
                'topic': ann['topic'],
                'category': ann['category'],
                'skill': ann['skill'],
                'lecture': ann['lecture'],
                'solution': ann['solution'],
                'split': ann['split'],
                'img_path': img_path,
                'has_image': has_image,
            })

        return data_list