Spaces:
Runtime error
Runtime error
File size: 5,515 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
# Copyright (c) OpenMMLab. All rights reserved.
import json
import os.path as osp
from collections import OrderedDict
from os import PathLike
from typing import List, Sequence, Union
from mmengine import get_file_backend
from mmpretrain.registry import DATASETS, TRANSFORMS
from .base_dataset import BaseDataset
def expanduser(data_prefix):
    """Expand a leading ``~`` in *data_prefix* when it is path-like.

    Non-path prefixes (e.g. a dict of per-key prefixes) are returned
    unchanged.
    """
    is_path = isinstance(data_prefix, (str, PathLike))
    return osp.expanduser(data_prefix) if is_path else data_prefix
@DATASETS.register_module()
class COCORetrieval(BaseDataset):
    """COCO Retrieval dataset.

    COCO (Common Objects in Context): The COCO dataset contains more than
    330K images, each of which has approximately 5 descriptive annotations.
    This dataset was released in collaboration between Microsoft and
    Carnegie Mellon University.

    COCO_2014 dataset directory: ::

        COCO_2014
        ├── val2014
        ├── train2014
        ├── annotations
             ├── instances_train2014.json
             ├── instances_val2014.json
             ├── person_keypoints_train2014.json
             ├── person_keypoints_val2014.json
             ├── captions_train2014.json
             ├── captions_val2014.json

    Args:
        ann_file (str): Annotation file path.
        test_mode (bool): Whether dataset is used for evaluation. This will
            decide the annotation format in data list annotations.
            Defaults to False.
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to ''.
        data_prefix (str | dict): Prefix for training data. Defaults to ''.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.

    Examples:
        >>> from mmpretrain.datasets import COCORetrieval
        >>> train_dataset=COCORetrieval(data_root='coco2014/')
        >>> train_dataset
        Dataset COCORetrieval
            Number of samples:  414113
            Annotation file:    /coco2014/annotations/captions_train2014.json
            Prefix of images:   /coco2014/
        >>> from mmpretrain.datasets import COCORetrieval
        >>> val_dataset = COCORetrieval(data_root='coco2014/')
        >>> val_dataset
        Dataset COCORetrieval
            Number of samples:  202654
            Annotation file:    /coco2014/annotations/captions_val2014.json
            Prefix of images:   /coco2014/
    """

    def __init__(self,
                 ann_file: str,
                 test_mode: bool = False,
                 data_prefix: Union[str, dict] = '',
                 data_root: str = '',
                 pipeline: Sequence = (),
                 **kwargs):
        # Normalize a bare string prefix into the dict form BaseDataset
        # expects, expanding ``~`` in the process.
        if isinstance(data_prefix, str):
            data_prefix = dict(img_path=expanduser(data_prefix))
        ann_file = expanduser(ann_file)

        # Build transform objects from config dicts; pass through
        # already-instantiated transforms unchanged.
        transforms = []
        for transform in pipeline:
            if isinstance(transform, dict):
                transforms.append(TRANSFORMS.build(transform))
            else:
                transforms.append(transform)

        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            test_mode=test_mode,
            pipeline=transforms,
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Returns:
            List[dict]: In test mode, one item per image (with all of its
            captions and ground-truth ids gathered); otherwise one item per
            caption annotation.
        """
        # get file backend
        img_prefix = self.data_prefix['img_path']
        file_backend = get_file_backend(img_prefix)

        # Fix: close the annotation file deterministically instead of
        # relying on GC to reclaim the handle from ``json.load(open(...))``.
        with open(self.ann_file, 'r') as f:
            anno_info = json.load(f)

        # mapping img_id to img filename
        img_dict = OrderedDict()
        for idx, img in enumerate(anno_info['images']):
            if img['id'] not in img_dict:
                # Keep the last two URL components (split dir + filename),
                # e.g. 'val2014/COCO_val2014_000000000042.jpg'.
                img_rel_path = img['coco_url'].rsplit('/', 2)[-2:]
                img_path = file_backend.join_path(img_prefix, *img_rel_path)

                # create new idx for image
                img_dict[img['id']] = dict(
                    ori_id=img['id'],
                    image_id=idx,  # will be used for evaluation
                    img_path=img_path,
                    text=[],
                    gt_text_id=[],
                    gt_image_id=[],
                )

        train_list = []
        for idx, anno in enumerate(anno_info['annotations']):
            anno['text'] = anno.pop('caption')
            anno['ori_id'] = anno.pop('id')
            anno['text_id'] = idx  # will be used for evaluation

            # 1. prepare train data list item
            train_data = anno.copy()
            train_image = img_dict[train_data['image_id']]
            train_data['img_path'] = train_image['img_path']
            train_data['image_ori_id'] = train_image['ori_id']
            train_data['image_id'] = train_image['image_id']
            train_data['is_matched'] = True
            train_list.append(train_data)

            # 2. prepare eval data list item based on img dict
            img_dict[anno['image_id']]['gt_text_id'].append(anno['text_id'])
            img_dict[anno['image_id']]['text'].append(anno['text'])
            img_dict[anno['image_id']]['gt_image_id'].append(
                train_image['image_id'])

        self.img_size = len(img_dict)
        self.text_size = len(anno_info['annotations'])

        # return needed format data list
        if self.test_mode:
            return list(img_dict.values())
        return train_list
|