# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import itertools
import logging
import numpy as np
import torch.utils.data
from tabulate import tabulate
from termcolor import colored

from detectron2.config import configurable
from detectron2.data.build import (
    build_batch_data_loader,
    trivial_batch_collator,
    load_proposals_into_dataset,
    filter_images_with_only_crowd_annotations,
    filter_images_with_few_keypoints,
    print_instances_class_histogram,
)
from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.detection_utils import check_metadata_consistency
from detectron2.data.samplers import (
    InferenceSampler,
    RandomSubsetTrainingSampler,
    RepeatFactorTrainingSampler,
    TrainingSampler,
)
from detectron2.utils.logger import _log_api_usage, log_first_n
""" | |
This file contains the default logic to build a dataloader for training or testing. | |
""" | |
__all__ = [ | |
"build_detection_train_loader", | |
"build_detection_test_loader", | |
] | |


def print_classification_instances_class_histogram(dataset_dicts, class_names):
    """
    Log a histogram of per-image classification instances.

    Args:
        dataset_dicts (list[dict]): list of dataset dicts.
        class_names (list[str]): list of class names (zero-indexed).
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        classes = np.asarray([entry["category_id"]], dtype=int)
        if len(classes):
            assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}"
            assert (
                classes.max() < num_classes
            ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes"
        histogram += np.histogram(classes, bins=hist_bins)[0]
    N_COLS = min(6, len(class_names) * 2)

    def short_name(x):
        # make long class names shorter. useful for lvis
        if len(x) > 13:
            return x[:11] + ".."
        return x

    data = list(
        itertools.chain(
            *[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)]
        )
    )
    total_num_instances = sum(data[1::2])
    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )
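

# A minimal usage sketch (hypothetical data; "cat"/"dog" are illustrative names):
#
#   dicts = [{"category_id": 0}, {"category_id": 1}, {"category_id": 1}]
#   print_classification_instances_class_histogram(dicts, ["cat", "dog"])
#   # logs a cyan table: cat -> 1, dog -> 2, total -> 3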


def wrap_metas(dataset_dicts, **kwargs):
    """
    Attach the given keyword arguments to every sample in ``dataset_dicts``
    under a single ``meta`` key.
    """

    def _assign_attr(data_dict: dict, **kwargs):
        assert not any(
            key in data_dict for key in kwargs
        ), "Assigned attributes should not exist in the original sample."
        data_dict.update(kwargs)
        return data_dict

    return [_assign_attr(sample, meta=kwargs) for sample in dataset_dicts]
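
# For example (a hedged sketch): wrap_metas([{"file_name": "a.jpg"}], dataset_name="d")
# returns [{"file_name": "a.jpg", "meta": {"dataset_name": "d"}}].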


def get_detection_dataset_dicts(
    names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names.
        filter_empty (bool): whether to filter out images without instance annotations.
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (list[str]): if given, a list of object proposal files
            that match each dataset in `names`.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names
    dataset_dicts = [
        wrap_metas(DatasetCatalog.get(dataset_name), dataset_name=dataset_name)
        for dataset_name in names
    ]
    for dataset_name, dicts in zip(names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(names) == len(proposal_files)
        # load precomputed proposals from proposal files
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    if has_instances:
        try:
            class_names = MetadataCatalog.get(names[0]).thing_classes
            check_metadata_consistency("thing_classes", names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:  # class names are not available for this dataset
            pass

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
    return dataset_dicts
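
# A minimal usage sketch ("coco_2017_train" assumes the standard detectron2
# builtin dataset registration; any registered name works):
#
#   dicts = get_detection_dataset_dicts("coco_2017_train", filter_empty=True)
#   print(len(dicts), dicts[0]["meta"])  # each dict carries its source dataset name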


def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
    if dataset is None:
        dataset = get_detection_dataset_dicts(
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN
            if cfg.MODEL.LOAD_PROPOSALS
            else None,
        )
        _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0])

    if mapper is None:
        mapper = DatasetMapper(cfg, True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = TrainingSampler(len(dataset))
        elif sampler_name == "RepeatFactorTrainingSampler":
            repeat_factors = (
                RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                    dataset, cfg.DATALOADER.REPEAT_THRESHOLD
                )
            )
            sampler = RepeatFactorTrainingSampler(repeat_factors)
        elif sampler_name == "RandomSubsetTrainingSampler":
            sampler = RandomSubsetTrainingSampler(
                len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO
            )
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }


# TODO can allow dataset as an iterable or IterableDataset to make this function more general
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset,
    *,
    mapper,
    sampler=None,
    total_batch_size,
    aspect_ratio_grouping=True,
    num_workers=0,
):
""" | |
Build a dataloader for object detection with some default features. | |
This interface is experimental. | |
Args: | |
dataset (list or torch.utils.data.Dataset): a list of dataset dicts, | |
or a map-style pytorch dataset. They can be obtained by using | |
:func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. | |
mapper (callable): a callable which takes a sample (dict) from dataset and | |
returns the format to be consumed by the model. | |
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``. | |
sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces | |
indices to be applied on ``dataset``. Default to :class:`TrainingSampler`, | |
which coordinates an infinite random shuffle sequence across all workers. | |
total_batch_size (int): total batch size across all workers. Batching | |
simply puts data into a list. | |
aspect_ratio_grouping (bool): whether to group images with similar | |
aspect ratio for efficiency. When enabled, it requires each | |
element in dataset be a dict with keys "width" and "height". | |
num_workers (int): number of parallel data loading workers | |
Returns: | |
torch.utils.data.DataLoader: | |
a dataloader. Each output from it is a ``list[mapped_element]`` of length | |
``total_batch_size / num_workers``, where ``mapped_element`` is produced | |
by the ``mapper``. | |
""" | |
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )
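
# A minimal usage sketch ("my_train" is a hypothetical registered dataset name;
# with the @configurable decorator, the loader can also be built straight from a cfg):
#
#   loader = build_detection_train_loader(
#       get_detection_dataset_dicts("my_train"),
#       mapper=DatasetMapper(cfg, is_train=True),
#       total_batch_size=16,
#   )
#   # or equivalently: loader = build_detection_train_loader(cfg)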


def _test_loader_from_config(cfg, dataset_name, mapper=None):
    """
    Uses the given `dataset_name` argument (instead of the names in cfg), because the
    standard practice is to evaluate each test set individually (not combining them).
    """
    if isinstance(dataset_name, str):
        dataset_name = [dataset_name]

    dataset = get_detection_dataset_dicts(
        dataset_name,
        filter_empty=False,
        proposal_files=[
            cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)]
            for x in dataset_name
        ]
        if cfg.MODEL.LOAD_PROPOSALS
        else None,
    )
    if mapper is None:
        mapper = DatasetMapper(cfg, False)
    return {
        "dataset": dataset,
        "mapper": mapper,
        "num_workers": 0,
        "samples_per_gpu": cfg.SOLVER.TEST_IMS_PER_BATCH,
    }


@configurable(from_config=_test_loader_from_config)
def build_detection_test_loader(
    dataset, *, mapper, sampler=None, num_workers=0, samples_per_gpu=1
):
""" | |
Similar to `build_detection_train_loader`, but uses a batch size of 1, | |
and :class:`InferenceSampler`. This sampler coordinates all workers to | |
produce the exact set of all samples. | |
This interface is experimental. | |
Args: | |
dataset (list or torch.utils.data.Dataset): a list of dataset dicts, | |
or a map-style pytorch dataset. They can be obtained by using | |
:func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. | |
mapper (callable): a callable which takes a sample (dict) from dataset | |
and returns the format to be consumed by the model. | |
When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. | |
sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces | |
indices to be applied on ``dataset``. Default to :class:`InferenceSampler`, | |
which splits the dataset across all workers. | |
num_workers (int): number of parallel data loading workers | |
Returns: | |
DataLoader: a torch DataLoader, that loads the given detection | |
dataset, with test-time transformation and batching. | |
Examples: | |
:: | |
data_loader = build_detection_test_loader( | |
DatasetRegistry.get("my_test"), | |
mapper=DatasetMapper(...)) | |
# or, instantiate with a CfgNode: | |
data_loader = build_detection_test_loader(cfg, "my_test") | |
""" | |
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)
    if sampler is None:
        sampler = InferenceSampler(len(dataset))
    # Unlike upstream detectron2 (which always uses 1 image per worker, the
    # standard when reporting inference time in papers), this variant batches
    # `samples_per_gpu` images per worker.
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, samples_per_gpu, drop_last=False
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
    )
    return data_loader
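
# A minimal usage sketch ("my_test" is a hypothetical registered dataset name):
#
#   loader = build_detection_test_loader(
#       get_detection_dataset_dicts("my_test", filter_empty=False),
#       mapper=DatasetMapper(cfg, is_train=False),
#   )
#   for batch in loader:  # each batch is a list of mapped dicts
#       outputs = model(batch)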