|
|
|
|
|
import os |
|
import pickle |
|
import sys |
|
import unittest |
|
from functools import partial |
|
import torch |
|
from iopath.common.file_io import LazyPath |
|
|
|
from detectron2 import model_zoo |
|
from detectron2.config import get_cfg, instantiate |
|
from detectron2.data import ( |
|
DatasetCatalog, |
|
DatasetFromList, |
|
MapDataset, |
|
ToIterableDataset, |
|
build_batch_data_loader, |
|
build_detection_test_loader, |
|
build_detection_train_loader, |
|
) |
|
from detectron2.data.common import ( |
|
AspectRatioGroupedDataset, |
|
set_default_dataset_from_list_serialize_method, |
|
) |
|
from detectron2.data.samplers import InferenceSampler, TrainingSampler |
|
|
|
|
|
def _a_slow_func(x): |
|
return "path/{}".format(x) |
|
|
|
|
|
class TestDatasetFromList(unittest.TestCase): |
|
|
|
@unittest.skipIf(sys.version_info.minor <= 6, "Not supported in Python 3.6") |
|
def test_using_lazy_path(self): |
|
dataset = [] |
|
for i in range(10): |
|
dataset.append({"file_name": LazyPath(partial(_a_slow_func, i))}) |
|
|
|
dataset = DatasetFromList(dataset) |
|
for i in range(10): |
|
path = dataset[i]["file_name"] |
|
self.assertTrue(isinstance(path, LazyPath)) |
|
self.assertEqual(os.fspath(path), _a_slow_func(i)) |
|
|
|
def test_alternative_serialize_method(self): |
|
dataset = [1, 2, 3] |
|
dataset = DatasetFromList(dataset, serialize=torch.tensor) |
|
self.assertEqual(dataset[2], torch.tensor(3)) |
|
|
|
def test_change_default_serialize_method(self): |
|
dataset = [1, 2, 3] |
|
with set_default_dataset_from_list_serialize_method(torch.tensor): |
|
dataset_1 = DatasetFromList(dataset, serialize=True) |
|
self.assertEqual(dataset_1[2], torch.tensor(3)) |
|
dataset_2 = DatasetFromList(dataset, serialize=True) |
|
self.assertEqual(dataset_2[2], 3) |
|
|
|
|
|
class TestMapDataset(unittest.TestCase): |
|
@staticmethod |
|
def map_func(x): |
|
if x == 2: |
|
return None |
|
return x * 2 |
|
|
|
def test_map_style(self): |
|
ds = DatasetFromList([1, 2, 3]) |
|
ds = MapDataset(ds, TestMapDataset.map_func) |
|
self.assertEqual(ds[0], 2) |
|
self.assertEqual(ds[2], 6) |
|
self.assertIn(ds[1], [2, 6]) |
|
|
|
def test_iter_style(self): |
|
class DS(torch.utils.data.IterableDataset): |
|
def __iter__(self): |
|
yield from [1, 2, 3] |
|
|
|
ds = DS() |
|
ds = MapDataset(ds, TestMapDataset.map_func) |
|
self.assertIsInstance(ds, torch.utils.data.IterableDataset) |
|
|
|
data = list(iter(ds)) |
|
self.assertEqual(data, [2, 6]) |
|
|
|
def test_pickleability(self): |
|
ds = DatasetFromList([1, 2, 3]) |
|
ds = MapDataset(ds, lambda x: x * 2) |
|
ds = pickle.loads(pickle.dumps(ds)) |
|
self.assertEqual(ds[0], 2) |
|
|
|
|
|
class TestAspectRatioGrouping(unittest.TestCase):
    """Tests for ``AspectRatioGroupedDataset``."""

    def test_reiter_leak(self):
        """Abandoning an iteration early must not leave a full bucket behind."""
        shapes = [(1, 0), (0, 1), (1, 0), (0, 1)]
        records = [{"width": w, "height": h} for (w, h) in shapes]
        batchsize = 2
        dataset = AspectRatioGroupedDataset(records, batchsize)

        for _ in range(5):
            for batch_index, _batch in enumerate(dataset):
                if batch_index != 1:
                    continue
                # Break by hand so the iterator does not stop by itself.
                break

            # Every bucket must hold fewer items than one full batch.
            for pending in dataset._buckets:
                self.assertLess(len(pending), batchsize)
|
|
|
|
|
class _MyData(torch.utils.data.IterableDataset): |
|
def __iter__(self): |
|
while True: |
|
yield 1 |
|
|
|
|
|
class TestDataLoader(unittest.TestCase): |
|
def _get_kwargs(self): |
|
|
|
cfg = model_zoo.get_config("common/data/coco.py").dataloader.train |
|
cfg.dataset.names = "coco_2017_val_100" |
|
cfg.pop("_target_") |
|
kwargs = {k: instantiate(v) for k, v in cfg.items()} |
|
return kwargs |
|
|
|
def test_build_dataloader_train(self): |
|
kwargs = self._get_kwargs() |
|
dl = build_detection_train_loader(**kwargs) |
|
next(iter(dl)) |
|
|
|
def test_build_iterable_dataloader_train(self): |
|
kwargs = self._get_kwargs() |
|
ds = DatasetFromList(kwargs.pop("dataset")) |
|
ds = ToIterableDataset(ds, TrainingSampler(len(ds))) |
|
dl = build_detection_train_loader(dataset=ds, **kwargs) |
|
next(iter(dl)) |
|
|
|
def test_build_iterable_dataloader_from_cfg(self): |
|
cfg = get_cfg() |
|
cfg.DATASETS.TRAIN = ["iter_data"] |
|
DatasetCatalog.register("iter_data", lambda: _MyData()) |
|
dl = build_detection_train_loader(cfg, mapper=lambda x: x, aspect_ratio_grouping=False) |
|
next(iter(dl)) |
|
|
|
dl = build_detection_test_loader(cfg, "iter_data", mapper=lambda x: x) |
|
next(iter(dl)) |
|
|
|
def _check_is_range(self, data_loader, N): |
|
|
|
data = list(iter(data_loader)) |
|
data = [x for batch in data for x in batch] |
|
self.assertEqual(len(data), N) |
|
self.assertEqual(set(data), set(range(N))) |
|
|
|
def test_build_batch_dataloader_inference(self): |
|
|
|
N = 96 |
|
ds = DatasetFromList(list(range(N))) |
|
sampler = InferenceSampler(len(ds)) |
|
dl = build_batch_data_loader(ds, sampler, 8, num_workers=3) |
|
self._check_is_range(dl, N) |
|
|
|
def test_build_batch_dataloader_inference_incomplete_batch(self): |
|
|
|
|
|
def _test(N, batch_size, num_workers): |
|
ds = DatasetFromList(list(range(N))) |
|
sampler = InferenceSampler(len(ds)) |
|
|
|
dl = build_batch_data_loader(ds, sampler, batch_size, num_workers=num_workers) |
|
data = list(iter(dl)) |
|
self.assertEqual(len(data), len(dl)) |
|
self._check_is_range(dl, N // batch_size * batch_size) |
|
|
|
dl = build_batch_data_loader( |
|
ds, sampler, batch_size, num_workers=num_workers, drop_last=False |
|
) |
|
data = list(iter(dl)) |
|
self.assertEqual(len(data), len(dl)) |
|
self._check_is_range(dl, N) |
|
|
|
_test(48, batch_size=8, num_workers=3) |
|
_test(47, batch_size=8, num_workers=3) |
|
_test(46, batch_size=8, num_workers=3) |
|
_test(40, batch_size=8, num_workers=3) |
|
_test(39, batch_size=8, num_workers=3) |
|
|
|
def test_build_dataloader_inference(self): |
|
N = 50 |
|
ds = DatasetFromList(list(range(N))) |
|
sampler = InferenceSampler(len(ds)) |
|
|
|
dl = build_detection_test_loader( |
|
dataset=ds, sampler=sampler, mapper=lambda x: x, num_workers=3 |
|
) |
|
self._check_is_range(dl, N) |
|
|
|
|
|
dl = build_detection_test_loader( |
|
dataset=ds, sampler=sampler, mapper=lambda x: x, batch_size=4, num_workers=0 |
|
) |
|
self._check_is_range(dl, N) |
|
|
|
def test_build_iterable_dataloader_inference(self): |
|
|
|
N = 50 |
|
ds = DatasetFromList(list(range(N))) |
|
ds = ToIterableDataset(ds, InferenceSampler(len(ds))) |
|
dl = build_detection_test_loader(dataset=ds, mapper=lambda x: x, num_workers=3) |
|
self._check_is_range(dl, N) |
|
|