import copy
import itertools
import functools

import numpy as np
import torch
import torch.utils.data
import torchvision.transforms as torch_transforms

import encoding.datasets as enc_ds


# Registry mapping a lowercase dataset name to a factory callable. Each factory
# is ``enc_ds.get_dataset`` with the dataset name pre-bound as its first
# positional argument, so callers only supply the remaining keyword arguments.
encoding_datasets = {
    x: functools.partial(enc_ds.get_dataset, x)
    for x in ["coco", "ade20k", "pascal_voc", "pascal_aug", "pcontext", "citys"]
}


def get_dataset(name, **kwargs):
    """Build the dataset registered under ``name``.

    The lookup is case-insensitive: ``name`` is lowercased before it is matched
    against the keys of ``encoding_datasets``.

    Args:
        name: Dataset name, e.g. ``"ade20k"`` or ``"ADE20K"``.
        **kwargs: Forwarded unchanged to ``enc_ds.get_dataset`` (presumably
            split / mode / transform options -- confirm against the
            ``encoding`` package).

    Returns:
        Whatever ``enc_ds.get_dataset(name, **kwargs)`` returns.

    Raises:
        AssertionError: If no dataset is registered under ``name``. It is raised
            explicitly rather than via ``assert`` so that the check still
            happens under ``python -O``. The exception type is kept for
            backward compatibility with callers that catch it.
    """
    # Normalize once and use the same key for both the check and the lookup.
    # Non-string names are passed through so they produce the "not found"
    # error instead of an AttributeError from ``.lower()``.
    key = name.lower() if isinstance(name, str) else name
    factory = encoding_datasets.get(key)
    if factory is None:
        raise AssertionError(f"dataset {name} not found")
    return factory(**kwargs)


def get_available_datasets():
    """Return the registered dataset names as a new list (lowercase keys)."""
    return list(encoding_datasets.keys())