"""Pascal VOC dataset wrappers: segmentation, multi-label classification
(optionally merged with SBD extra annotations), and cached-results readers."""

import os
import tarfile

import h5py
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from scipy import io
from torchvision.datasets.utils import download_url

# Download metadata per VOC release year.  `base_dir` is the directory the
# tarball extracts into, relative to the dataset root.
DATASET_YEAR_DICT = {
    '2012': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
        'filename': 'VOCtrainval_11-May-2012.tar',
        'md5': '6cd6e144f989b92b3379bac3b3de84fd',
        'base_dir': 'VOCdevkit/VOC2012'
    },
    '2011': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
        'filename': 'VOCtrainval_25-May-2011.tar',
        'md5': '6c3384ef61512963050cb5d687e5bf1e',
        'base_dir': 'TrainVal/VOCdevkit/VOC2011'
    },
    '2010': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
        'filename': 'VOCtrainval_03-May-2010.tar',
        'md5': 'da459979d0c395079b5c75ee67908abb',
        'base_dir': 'VOCdevkit/VOC2010'
    },
    '2009': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
        'filename': 'VOCtrainval_11-May-2009.tar',
        'md5': '59065e4b188729180974ef6572f6a212',
        'base_dir': 'VOCdevkit/VOC2009'
    },
    '2008': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
        # BUGFIX: was 'VOCtrainval_11-May-2012.tar' (copy-paste error) — the
        # saved filename must match the 2008 tarball named in the URL above,
        # otherwise the md5 check and extraction target the wrong archive.
        'filename': 'VOCtrainval_14-Jul-2008.tar',
        'md5': '2629fa636546599198acfcfbfcf1904a',
        'base_dir': 'VOCdevkit/VOC2008'
    },
    '2007': {
        'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
        'filename': 'VOCtrainval_06-Nov-2007.tar',
        'md5': 'c52e279531787c972589f7e41ab4ae64',
        'base_dir': 'VOCdevkit/VOC2007'
    }
}


def download_extract(url, root, filename, md5):
    """Download the VOC tarball into `root` (md5-verified) and extract it in place."""
    download_url(url, root, filename, md5)
    # NOTE(security): extractall() trusts the archive's member paths.  The
    # tarball comes from the official VOC host and is md5-checked above; do
    # not reuse this helper on untrusted archives.
    with tarfile.open(os.path.join(root, filename), "r") as tar:
        tar.extractall(path=root)


def _read_segmentation_split(voc_root, image_set):
    """Return the list of image ids listed in ImageSets/Segmentation/<image_set>.txt.

    Raises:
        ValueError: if the split file does not exist (bad `image_set`).
    """
    splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
    split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
    if not os.path.exists(split_f):
        raise ValueError(
            'Wrong image_set entered! Please use image_set="train" '
            'or image_set="trainval" or image_set="val"')
    with open(split_f, "r") as f:
        return [x.strip() for x in f.readlines()]


def _presence_labels(target, num_classes):
    """Convert a segmentation mask to a multi-hot class-presence vector.

    Index `c - 1` of the returned float tensor is 1 iff class id `c` appears
    in `target`.  Ids 0 (background) and 255 (void/ambiguous boundary) are
    ignored.
    """
    labels = torch.zeros(num_classes)
    for class_id in np.unique(target):
        if class_id not in (0, 255):
            labels[class_id - 1] = 1
    return labels


class VOCSegmentation(data.Dataset):
    """`Pascal VOC `_ Segmentation Dataset.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an
            PIL image and returns a transformed version. E.g,
            ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
    """

    CLASSES = 20
    CLASSES_NAMES = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
        'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
        'person', 'potted-plant', 'sheep', 'sofa', 'train', 'tvmonitor',
        'ambigious'
    ]

    def __init__(self, root, year='2012', image_set='train', download=False,
                 transform=None, target_transform=None):
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.transform = transform
        self.target_transform = target_transform
        self.image_set = image_set

        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')
        mask_dir = os.path.join(voc_root, 'SegmentationClass')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        file_names = _read_segmentation_split(voc_root, image_set)
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the image segmentation.
                NOTE: target is converted to a LongTensor (255 -> -1) only
                when ``target_transform`` is set; otherwise it stays a PIL
                image — callers relying on tensors must pass a
                target_transform.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = np.array(self.target_transform(target)).astype('int32')
            target[target == 255] = -1  # void label -> ignore index
            target = torch.from_numpy(target).long()
        return img, target

    @staticmethod
    def _mask_transform(mask):
        # Same 255 -> -1 remapping as __getitem__, usable as a standalone
        # target_transform.
        target = np.array(mask).astype('int32')
        target[target == 255] = -1
        return torch.from_numpy(target).long()

    def __len__(self):
        return len(self.images)

    @property
    def pred_offset(self):
        # Predictions map directly onto label ids (no class-id shift).
        return 0


class VOCClassification(data.Dataset):
    """`Pascal VOC `_ multi-label classification built from segmentation masks.

    Targets are 20-dim multi-hot vectors of the classes visible in each mask.

    Args:
        root (string): Root directory of the VOC Dataset.
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        transform (callable, optional): A joint transform called as
            ``transform(img, target)`` returning both transformed.
    """

    CLASSES = 20

    def __init__(self, root, year='2012', image_set='train', download=False,
                 transform=None):
        self.root = os.path.expanduser(root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.transform = transform
        self.image_set = image_set

        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')
        mask_dir = os.path.join(voc_root, 'SegmentationClass')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        file_names = _read_segmentation_split(voc_root, image_set)
        self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, labels) where labels is a multi-hot tensor of the
                classes present in the segmentation mask.
        """
        img = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.masks[index])
        if self.transform is not None:
            # Joint transform: image and mask must be warped consistently.
            img, target = self.transform(img, target)
        labels = _presence_labels(target, self.CLASSES)
        return img, labels

    def __len__(self):
        return len(self.images)


class VOCSBDClassification(data.Dataset):
    """VOC + SBD multi-label classification dataset.

    Concatenates the VOC segmentation split with the SBD (Berkeley
    Semantic Boundaries) ``train.txt`` split; SBD masks are stored as
    MATLAB ``.mat`` files.

    Args:
        root (string): Root directory of the VOC Dataset.
        sbd_root (string): Root directory of the SBD dataset (contains
            ``img/``, ``cls/`` and ``train.txt``).
        year (string, optional): The dataset year, supports years 2007 to 2012.
        image_set (string, optional): Select the image_set to use, ``train``,
            ``trainval`` or ``val``
        download (bool, optional): If true, downloads the VOC dataset from the
            internet and puts it in root directory. If dataset is already
            downloaded, it is not downloaded again.
        transform (callable, optional): A joint transform called as
            ``transform(img, target)`` returning both transformed.
    """

    CLASSES = 20

    def __init__(self, root, sbd_root, year='2012', image_set='train',
                 download=False, transform=None):
        self.root = os.path.expanduser(root)
        self.sbd_root = os.path.expanduser(sbd_root)
        self.year = year
        self.url = DATASET_YEAR_DICT[year]['url']
        self.filename = DATASET_YEAR_DICT[year]['filename']
        self.md5 = DATASET_YEAR_DICT[year]['md5']
        self.transform = transform
        self.image_set = image_set

        base_dir = DATASET_YEAR_DICT[year]['base_dir']
        voc_root = os.path.join(self.root, base_dir)
        image_dir = os.path.join(voc_root, 'JPEGImages')
        mask_dir = os.path.join(voc_root, 'SegmentationClass')
        sbd_image_dir = os.path.join(sbd_root, 'img')
        sbd_mask_dir = os.path.join(sbd_root, 'cls')

        if download:
            download_extract(self.url, self.root, self.filename, self.md5)
        if not os.path.isdir(voc_root):
            raise RuntimeError('Dataset not found or corrupted.'
                               ' You can use download=True to download it')

        voc_file_names = _read_segmentation_split(voc_root, image_set)
        sbd_split = os.path.join(sbd_root, 'train.txt')
        with open(sbd_split, "r") as f:
            sbd_file_names = [x.strip() for x in f.readlines()]

        # VOC samples first, then SBD; images and masks stay index-aligned.
        self.images = [os.path.join(image_dir, x + ".jpg") for x in voc_file_names]
        self.images += [os.path.join(sbd_image_dir, x + ".jpg") for x in sbd_file_names]
        self.masks = [os.path.join(mask_dir, x + ".png") for x in voc_file_names]
        self.masks += [os.path.join(sbd_mask_dir, x + ".mat") for x in sbd_file_names]
        assert (len(self.images) == len(self.masks))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, labels) where labels is a multi-hot tensor of the
                classes present in the segmentation mask.
        """
        img = Image.open(self.images[index]).convert('RGB')
        mask_path = self.masks[index]
        if mask_path.endswith('mat'):
            # SBD masks: label map lives in the 'GTcls' struct of the .mat file.
            target = io.loadmat(
                mask_path, struct_as_record=False,
                squeeze_me=True)['GTcls'].Segmentation
            target = Image.fromarray(target, mode='P')
        else:
            target = Image.open(self.masks[index])
        if self.transform is not None:
            # Joint transform: image and mask must be warped consistently.
            img, target = self.transform(img, target)
        labels = _presence_labels(target, self.CLASSES)
        return img, labels

    def __len__(self):
        return len(self.images)


class VOCResults(data.Dataset):
    """Reader for cached model outputs stored in ``<path>/results.hdf5``.

    Each item yields (image, vis, target, class_pred) tensors read from the
    corresponding HDF5 datasets.
    """

    CLASSES = 20
    CLASSES_NAMES = [
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
        'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
        'person', 'potted-plant', 'sheep', 'sofa', 'train', 'tvmonitor',
        'ambigious'
    ]

    def __init__(self, path):
        super(VOCResults, self).__init__()
        self.path = os.path.join(path, 'results.hdf5')
        # File handle is opened lazily in __getitem__ so each DataLoader
        # worker gets its own handle (h5py handles are not fork-safe).
        self.data = None

        print('Reading dataset length...')
        with h5py.File(self.path, 'r') as f:
            self.data_length = len(f['/image'])

    def __len__(self):
        return self.data_length

    def __getitem__(self, item):
        if self.data is None:
            # Kept open for the lifetime of the worker process.
            self.data = h5py.File(self.path, 'r')

        image = torch.tensor(self.data['image'][item])
        vis = torch.tensor(self.data['vis'][item])
        target = torch.tensor(self.data['target'][item])
        class_pred = torch.tensor(self.data['class_pred'][item])

        return image, vis, target, class_pred