# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class OCRVQA(BaseDataset):
    """OCR-VQA dataset.

    Every entry of the annotation file describes one image together with a
    list of questions and a parallel list of answers; each question/answer
    pair becomes a separate sample.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str): Annotation file path. A single file holds the
            annotations of all splits; entries are filtered by their integer
            ``split`` field (1: train, 2: val, 3: test).
        split (str): 'train', 'val' or 'test'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str,
                 split: str,
                 **kwargs):
        assert split in ['train', 'val', 'test'], \
            '`split` must be train, val or test'
        # Assigned before ``super().__init__`` because ``load_data_list``
        # reads it, and BaseDataset may load the data list during
        # construction (unless ``lazy_init`` is set).
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load the samples that belong to ``self.split``.

        Entries whose image URL extension is not ``.jpg`` or ``.png`` are
        skipped, as are entries whose ``split`` code is not 1, 2 or 3.

        Returns:
            List[dict]: One dict per question/answer pair with the keys
            ``img_path``, ``question``, ``gt_answer``, ``gt_answer_weight``,
            ``imageURL``, ``title``, ``authorName`` and ``genre``.
        """
        split_dict = {1: 'train', 2: 'val', 3: 'test'}
        img_prefix = self.data_prefix['img_path']
        annotations = mmengine.load(self.ann_file)
        # ann example
        # "761183272": {
        #     "imageURL": \
        #         "http://ecx.images-amazon.com/images/I/61Y5cOdHJbL.jpg",
        #     "questions": [
        #         "Who wrote this book?",
        #         "What is the title of this book?",
        #         "What is the genre of this book?",
        #         "Is this a games related book?",
        #         "What is the year printed on this calendar?"],
        #     "answers": [
        #         "Sandra Boynton",
        #         "Mom's Family Wall Calendar 2016",
        #         "Calendars",
        #         "No",
        #         "2016"],
        #     "title": "Mom's Family Wall Calendar 2016",
        #     "authorName": "Sandra Boynton",
        #     "genre": "Calendars",
        #     "split": 1
        # },
        data_list = []
        for key, ann in annotations.items():
            # ``.get`` so an unknown split code skips the entry instead of
            # aborting the whole load with a bare ``KeyError``.
            if split_dict.get(ann['split']) != self.split:
                continue
            extension = osp.splitext(ann['imageURL'])[1]
            if extension not in ('.jpg', '.png'):
                continue
            # The image file is looked up as ``<key><extension>`` under the
            # image prefix.
            img_path = mmengine.join_path(img_prefix, key + extension)
            # ``zip`` pairs questions with answers positionally; any surplus
            # items in the longer list are dropped.
            for question, answer in zip(ann['questions'], ann['answers']):
                data_list.append({
                    'img_path': img_path,
                    'question': question,
                    'gt_answer': answer,
                    # A single ground-truth answer with full weight; a fresh
                    # list per sample so no two samples share it.
                    'gt_answer_weight': [1.0],
                    'imageURL': ann['imageURL'],
                    'title': ann['title'],
                    'authorName': ann['authorName'],
                    'genre': ann['genre'],
                })
        return data_list