Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from collections import Counter | |
from typing import List | |
import mmengine | |
from mmengine.dataset import BaseDataset | |
from mmpretrain.registry import DATASETS | |
# NOTE(review): ``DATASETS`` is imported by this file but is not used by this
# class -- confirm whether ``@DATASETS.register_module()`` was intended here.
class TextVQA(BaseDataset):
    """TextVQA dataset.

    val image:
        https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
    test image:
        https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
    val json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json
    test json:
        https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_test.json

    folder structure:
        data/textvqa
        ├── annotations
        │   ├── TextVQA_0.5.1_test.json
        │   └── TextVQA_0.5.1_val.json
        └── images
            ├── test_images
            └── train_images

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path. Defaults to an
            empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwarg):
        # The image directory is forwarded under the ``img_path`` key, which
        # is the key ``load_data_list`` reads back from ``self.data_prefix``.
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwarg,
        )

    def load_data_list(self) -> List[dict]:
        """Load the annotation file and build one sample dict per question.

        The file at ``self.ann_file`` is loaded with ``mmengine.load`` and its
        ``'data'`` entry is iterated. Each record looks like::

            {
                'question': 'what is the brand of...is camera?',
                'image_id': '003a8ae2ef43b901',
                'answers': ['nous les gosses', 'dakota', ...],
                'question_id': 34602,
                ...
            }

        Returns:
            List[dict]: One dict per record with the keys ``question``,
            ``question_id``, ``image_id`` and ``img_path``. When the record
            has an ``'answers'`` field, the dict also contains ``gt_answer``
            (the distinct answers, in first-seen order) and
            ``gt_answer_weight`` (for each distinct answer, the fraction of
            all listed answers that equal it).
        """
        annotations = mmengine.load(self.ann_file)['data']
        img_dir = self.data_prefix['img_path']

        data_list = []
        for ann in annotations:
            data_info = dict(question=ann['question'])
            data_info['question_id'] = ann['question_id']
            data_info['image_id'] = ann['image_id']
            # Images are stored as ``<image_id>.jpg`` under the image prefix.
            data_info['img_path'] = mmengine.join_path(
                img_dir, ann['image_id'] + '.jpg')

            # Records without an 'answers' field yield no ground-truth keys.
            if 'answers' in ann:
                # Read without popping so the loaded record is not mutated.
                answers = list(ann['answers'])
                count = Counter(answers)
                # An empty ``answers`` list gives an empty ``count``, so the
                # division below is never evaluated with a zero denominator.
                data_info['gt_answer'] = list(count.keys())
                data_info['gt_answer_weight'] = [
                    num / len(answers) for num in count.values()
                ]

            data_list.append(data_info)

        return data_list