# Source: TTP / mmpretrain/datasets/textvqa.py
# (uploaded by KyanChen, "Upload 1861 files", commit 3b96cb1)
# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter
from typing import List
import mmengine
from mmengine.dataset import BaseDataset
from mmpretrain.registry import DATASETS
@DATASETS.register_module()
class TextVQA(BaseDataset):
    """TextVQA dataset.

    val image:
        https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip

    test image:
        https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip

    val json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json

    test json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json

    folder structure:
        data/textvqa
            ├── annotations
            │   ├── TextVQA_0.5.1_test.json
            │   └── TextVQA_0.5.1_val.json
            └── images
                ├── test_images
                └── train_images

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images. Each image is expected
            to be named ``<image_id>.jpg`` inside it.
        ann_file (str): Path of a TextVQA json file (for example
            ``TextVQA_0.5.1_val.json`` or ``TextVQA_0.5.1_test.json``).
            Its top-level ``'data'`` list provides the samples.
            Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            # BaseDataset expects a dict of prefixes; images go under
            # the 'img_path' key, which load_data_list reads back.
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load the data list from ``self.ann_file``.

        Every entry of the annotation's ``'data'`` list becomes one sample
        holding ``question``, ``question_id``, ``image_id`` and
        ``img_path``. When the entry carries an ``'answers'`` list, the
        distinct answers are stored as ``gt_answer`` and, in the same
        order, their relative frequencies (count / total number of
        answers) as ``gt_answer_weight``.

        Returns:
            List[dict]: One dict per question.
        """
        annotations = mmengine.load(self.ann_file)['data']
        img_root = self.data_prefix['img_path']

        data_list = []
        for ann in annotations:
            # ann example
            # {
            #     'question': 'what is the brand of...is camera?',
            #     'image_id': '003a8ae2ef43b901',
            #     'image_classes': [
            #         'Cassette deck', 'Printer', ...
            #     ],
            #     'flickr_original_url': 'https://farm2.static...04a6_o.jpg',
            #     'flickr_300k_url': 'https://farm2.static...04a6_o.jpg',
            #     'image_width': 1024,
            #     'image_height': 664,
            #     'answers': [
            #         'nous les gosses',
            #         'dakota',
            #         'clos culombu',
            #         'dakota digital' ...
            #     ],
            #     'question_tokens':
            #         ['what', 'is', 'the', 'brand', 'of', 'this', 'camera'],
            #     'question_id': 34602,
            #     'set_name': 'val'
            # }
            image_id = ann['image_id']
            data_info = dict(
                question=ann['question'],
                question_id=ann['question_id'],
                image_id=image_id,
                img_path=mmengine.join_path(img_root, image_id + '.jpg'),
            )

            # Answers are optional (entries without them get no
            # ground truth); read them without mutating the loaded entry.
            if 'answers' in ann:
                answers = ann['answers']
                num_answers = len(answers)
                count = Counter(answers)
                data_info['gt_answer'] = list(count.keys())
                data_info['gt_answer_weight'] = [
                    n / num_answers for n in count.values()
                ]

            data_list.append(data_info)
        return data_list