|
import os |
|
import torch |
|
import torch.utils.data as data |
|
import numpy as np |
|
|
|
from PIL import Image |
|
import h5py |
|
|
|
# Names exported by a star-import of this module. Only ImagenetResults is
# listed; Imagenet_Segmentation is defined below but must be imported by name.
__all__ = ['ImagenetResults']
|
|
|
|
|
class Imagenet_Segmentation(data.Dataset):
    """Image / ground-truth pairs read on demand from a single HDF5 file.

    The file is expected to contain the datasets ``/value/img`` and
    ``/value/gt``. The number of samples is taken from ``/value/img`` when
    the dataset is constructed; the file itself is re-opened on the first
    ``__getitem__`` call and that handle is kept on ``self.h5py``.
    """

    # Class-level constant; nothing inside this class reads it.
    # NOTE(review): presumably the number of segmentation labels
    # (foreground / background) - confirm against callers.
    CLASSES = 2

    def __init__(self,
                 path,
                 transform=None,
                 target_transform=None):
        """Store the configuration and read the number of samples.

        Args:
            path: Path to the HDF5 file.
            transform: Optional callable applied to the RGB ``PIL.Image``
                of each sample.
            target_transform: Optional callable applied to the ground-truth
                ``PIL.Image`` of each sample.
        """
        super(Imagenet_Segmentation, self).__init__()
        self.path = path
        self.transform = transform
        self.target_transform = target_transform
        # The long-lived handle is opened on first item access, not here.
        # NOTE(review): presumably so that each DataLoader worker process
        # opens its own handle - confirm.
        self.h5py = None

        # The context manager closes the file even when the lookup or
        # len() raises, so a bad file does not leak an open handle.
        with h5py.File(path, 'r') as f:
            self.data_length = len(f['/value/img'])

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Row index into ``/value/img`` and ``/value/gt``.

        Returns:
            A tuple ``(img, target)``. ``img`` is the RGB ``PIL.Image``, or
            whatever ``transform`` returns for it. ``target`` is an int64
            tensor built from the output of ``target_transform`` when one
            was given; otherwise it is the untransformed ``PIL.Image``.
        """
        if self.h5py is None:
            self.h5py = h5py.File(self.path, 'r')

        # The value stored at [index, 0] is used as a key into the file to
        # fetch the actual array.
        # NOTE(review): presumably an HDF5 object reference - confirm.
        img_ref = self.h5py['/value/img'][index, 0]
        # The axis order of the stored 3-D array is reversed.
        # NOTE(review): presumably to obtain an H x W x C layout - confirm.
        img = np.array(self.h5py[img_ref]).transpose((2, 1, 0))

        # The ground truth needs one more hop: the looked-up object's
        # [0, 0] entry is itself used as a key for the mask array.
        gt_ref = self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]
        target = np.array(self.h5py[gt_ref]).transpose((1, 0))

        img = Image.fromarray(img).convert('RGB')
        target = Image.fromarray(target)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            # The tensor conversion happens only on this branch; without a
            # target_transform the PIL image is returned unchanged.
            target = np.array(self.target_transform(target)).astype('int32')
            target = torch.from_numpy(target).long()

        return img, target

    def __len__(self):
        """Return the number of samples counted at construction time."""
        return self.data_length
|
|
|
|
|
class ImagenetResults(data.Dataset):
    """Dataset over the ``results.hdf5`` file found in a results directory.

    Every item is an ``(image, vis, target)`` triple of tensors read from
    the ``image``, ``vis`` and ``target`` datasets of that file.
    """

    def __init__(self, path):
        """Record where the results file lives and how many items it holds.

        Args:
            path: Directory that contains ``results.hdf5``.
        """
        super(ImagenetResults, self).__init__()

        # No long-lived handle yet; __getitem__ opens it on first use.
        self.data = None
        self.path = os.path.join(path, 'results.hdf5')

        print('Reading dataset length...')
        # Short-lived handle, used only to count the entries of '/image'.
        with h5py.File(self.path, 'r') as results:
            self.data_length = len(results['/image'])

    def __getitem__(self, item):
        """Return the ``(image, vis, target)`` tensors stored at ``item``.

        ``target`` is converted to int64; ``image`` and ``vis`` keep the
        dtype that ``torch.tensor`` infers from the stored data.
        """
        if self.data is None:
            self.data = h5py.File(self.path, 'r')
        store = self.data

        # Read the three datasets in the same order: image, vis, target.
        image, vis, target = (
            torch.tensor(store[key][item])
            for key in ('image', 'vis', 'target')
        )
        return image, vis, target.long()

    def __len__(self):
        """Return the item count measured when the dataset was created."""
        return self.data_length
|
|