Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import json | |
import os.path as osp | |
from collections import OrderedDict | |
from os import PathLike | |
from typing import List, Sequence, Union | |
from mmengine import get_file_backend | |
from mmpretrain.registry import DATASETS, TRANSFORMS | |
from .base_dataset import BaseDataset | |
def expanduser(data_prefix):
    """Expand a leading ``~`` in a path-like prefix.

    Args:
        data_prefix: A ``str`` or ``os.PathLike`` path, or any other object
            (for example a ``dict`` of prefixes).

    Returns:
        The result of ``os.path.expanduser`` when ``data_prefix`` is a
        ``str`` or ``os.PathLike``; otherwise ``data_prefix`` itself,
        untouched.
    """
    is_path = isinstance(data_prefix, (str, PathLike))
    return osp.expanduser(data_prefix) if is_path else data_prefix
class COCORetrieval(BaseDataset):
    """COCO Retrieval dataset.

    COCO (Common Objects in Context): The COCO dataset contains more than
    330K images, each of which has approximately 5 descriptive annotations.
    This dataset was released in collaboration between Microsoft and Carnegie
    Mellon University.

    COCO_2014 dataset directory: ::

        COCO_2014
        ├── val2014
        ├── train2014
        └── annotations
            ├── instances_train2014.json
            ├── instances_val2014.json
            ├── person_keypoints_train2014.json
            ├── person_keypoints_val2014.json
            ├── captions_train2014.json
            └── captions_val2014.json

    Args:
        ann_file (str): Annotation file path.
        test_mode (bool): Whether dataset is used for evaluation. This will
            decide the annotation format in data list annotations.
            Defaults to False.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. A ``str`` is
            wrapped as ``dict(img_path=...)``. Defaults to ''.
        pipeline (Sequence): Processing pipeline. ``dict`` items are built
            through the ``TRANSFORMS`` registry, other items are used as
            they are. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.

    Examples:
        >>> from mmpretrain.datasets import COCORetrieval
        >>> train_dataset = COCORetrieval(
        ...     data_root='coco2014/',
        ...     ann_file='annotations/captions_train2014.json')
        >>> train_dataset
        Dataset COCORetrieval
            Number of samples: 414113
            Annotation file: /coco2014/annotations/captions_train2014.json
            Prefix of images: /coco2014/
        >>> val_dataset = COCORetrieval(
        ...     data_root='coco2014/',
        ...     ann_file='annotations/captions_val2014.json')
        >>> val_dataset
        Dataset COCORetrieval
            Number of samples: 202654
            Annotation file: /coco2014/annotations/captions_val2014.json
            Prefix of images: /coco2014/
    """

    def __init__(self,
                 ann_file: str,
                 test_mode: bool = False,
                 data_prefix: Union[str, dict] = '',
                 data_root: str = '',
                 pipeline: Sequence = (),
                 **kwargs):
        # ``load_data_list`` reads ``self.data_prefix['img_path']``, so a
        # plain string prefix is normalised into that dict form here.
        if isinstance(data_prefix, str):
            data_prefix = dict(img_path=expanduser(data_prefix))

        ann_file = expanduser(ann_file)

        # Build config dicts into transform objects; anything that is not a
        # dict is passed through unchanged.
        transforms = []
        for transform in pipeline:
            if isinstance(transform, dict):
                transforms.append(TRANSFORMS.build(transform))
            else:
                transforms.append(transform)

        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            pipeline=transforms,
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load the caption annotations and build the data list.

        Also sets ``self.img_size`` (number of distinct images) and
        ``self.text_size`` (number of caption annotations).

        Returns:
            List[dict]: When ``self.test_mode`` is true, one dict per image
            with the keys ``ori_id``, ``image_id``, ``img_path``, ``text``,
            ``gt_text_id`` and ``gt_image_id`` (the last three are lists
            filled from the image's captions). Otherwise, one dict per
            caption: a copy of the annotation where ``caption`` is renamed
            to ``text`` and ``id`` to ``ori_id``, with ``text_id``,
            ``img_path``, ``image_ori_id`` and ``is_matched=True`` added and
            ``image_id`` replaced by the re-indexed image id.
        """
        # get file backend
        img_prefix = self.data_prefix['img_path']
        file_backend = get_file_backend(img_prefix)

        # Use a context manager so the file handle is closed deterministically,
        # and read as UTF-8 (the JSON interchange encoding) instead of the
        # platform default.
        with open(self.ann_file, 'r', encoding='utf-8') as f:
            anno_info = json.load(f)

        # mapping img_id to img filename
        img_dict = OrderedDict()
        for idx, img in enumerate(anno_info['images']):
            if img['id'] in img_dict:
                # keep the first record of a repeated image id
                continue
            # The last two URL components are joined onto the image prefix,
            # e.g. ``.../train2014/xxx.jpg`` -> ``train2014/xxx.jpg``.
            img_rel_path = img['coco_url'].rsplit('/', 2)[-2:]
            img_path = file_backend.join_path(img_prefix, *img_rel_path)

            # create new idx for image
            # NOTE(review): ``idx`` is the position in ``images``, so skipped
            # duplicates would leave gaps in ``image_id`` - confirm the
            # evaluator tolerates that.
            img_dict[img['id']] = dict(
                ori_id=img['id'],
                image_id=idx,  # will be used for evaluation
                img_path=img_path,
                text=[],
                gt_text_id=[],
                gt_image_id=[],
            )

        train_list = []
        for idx, anno in enumerate(anno_info['annotations']):
            anno['text'] = anno.pop('caption')
            anno['ori_id'] = anno.pop('id')
            anno['text_id'] = idx  # will be used for evaluation

            # ``anno['image_id']`` is still the original COCO image id here.
            image_entry = img_dict[anno['image_id']]

            # 1. prepare train data list item
            train_data = anno.copy()
            train_data['img_path'] = image_entry['img_path']
            train_data['image_ori_id'] = image_entry['ori_id']
            train_data['image_id'] = image_entry['image_id']
            train_data['is_matched'] = True
            train_list.append(train_data)

            # 2. prepare eval data list item based on img dict
            image_entry['gt_text_id'].append(anno['text_id'])
            image_entry['text'].append(anno['text'])
            image_entry['gt_image_id'].append(image_entry['image_id'])

        self.img_size = len(img_dict)
        self.text_size = len(anno_info['annotations'])

        # return needed format data list
        if self.test_mode:
            return list(img_dict.values())
        return train_list