# Source page header (kept as a comment so the module parses):
# author: Hila -- commit 7754b29 "init commit" -- file size 2.03 kB
import os
import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
import h5py
# Names exported by ``from <module> import *``; Imagenet_Segmentation is not
# listed here and must be imported explicitly.
__all__ = [
    'ImagenetResults',
]
class Imagenet_Segmentation(data.Dataset):
    """ImageNet segmentation dataset read from a single HDF5 file.

    The file must contain ``/value/img`` and ``/value/gt`` datasets. Their
    entries are used as keys back into the same file, presumably HDF5 object
    references from a MATLAB v7.3 cell-array layout -- TODO confirm.

    Each item is an ``(img, target)`` pair:

    * ``img`` is an RGB PIL image, passed through ``transform`` if given.
    * ``target`` is the ground-truth mask. It is a PIL image when
      ``target_transform`` is None, otherwise the transformed mask converted
      to a ``torch.long`` tensor.
    """

    # Number of segmentation classes (presumably background/foreground -- confirm).
    CLASSES = 2

    def __init__(self,
                 path,
                 transform=None,
                 target_transform=None):
        """Record the file location and read the number of samples.

        Args:
            path: Path to the HDF5 file.
            transform: Optional callable applied to the RGB PIL image.
            target_transform: Optional callable applied to the mask PIL image.
                Its result is converted to an int64 tensor.
        """
        super().__init__()
        self.path = path
        self.transform = transform
        self.target_transform = target_transform
        # The read handle is opened lazily in __getitem__ rather than kept from
        # here (presumably so each DataLoader worker opens its own -- confirm).
        self.h5py = None
        # Context manager: the temporary handle is closed even if the lookup
        # below raises (the original explicit close() leaked it in that case).
        with h5py.File(path, 'r') as tmp:
            self.data_length = len(tmp['/value/img'])

    def __getitem__(self, index):
        """Return the ``(img, target)`` pair stored at ``index``."""
        if self.h5py is None:
            self.h5py = h5py.File(self.path, 'r')
        # The stored image is (C, W, H); transposing gives the (H, W, C)
        # layout that Image.fromarray expects. Presumably this is because the
        # file was written column-major (MATLAB) -- TODO confirm.
        img = np.array(self.h5py[self.h5py['/value/img'][index, 0]]).transpose((2, 1, 0))
        # The ground truth has two levels of indirection. The /value/gt entry
        # names an object, and that object's [0, 0] element names the mask.
        target = np.array(self.h5py[self.h5py[self.h5py['/value/gt'][index, 0]][0, 0]]).transpose((1, 0))
        img = Image.fromarray(img).convert('RGB')
        target = Image.fromarray(target)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            # Only the transformed target is turned into a long tensor; without
            # a target_transform the PIL image is returned unchanged.
            target = np.array(self.target_transform(target)).astype('int32')
            target = torch.from_numpy(target).long()
        return img, target

    def __len__(self):
        """Return the number of samples found in ``/value/img``."""
        return self.data_length
class ImagenetResults(data.Dataset):
    """Dataset over precomputed results stored in ``<path>/results.hdf5``.

    Each item is an ``(image, vis, target)`` triple of tensors. They are read
    at the same index from the ``image``, ``vis`` and ``target`` datasets, and
    ``target`` is cast to ``torch.long``.
    """

    def __init__(self, path):
        """Record the HDF5 location and read how many entries it holds.

        Args:
            path: Directory that contains ``results.hdf5``.
        """
        super().__init__()
        self.path = os.path.join(path, 'results.hdf5')
        # Read handle, opened lazily on the first __getitem__ call.
        self.data = None

        print('Reading dataset length...')
        with h5py.File(self.path, 'r') as handle:
            self.data_length = len(handle['/image'])

    def __len__(self):
        """Return the number of entries counted in ``/image`` at construction."""
        return self.data_length

    def __getitem__(self, item):
        """Return the ``(image, vis, target)`` tensors stored at ``item``."""
        if self.data is None:
            # Deferred open. Presumably this gives each DataLoader worker its
            # own handle instead of inheriting one from the parent -- confirm.
            self.data = h5py.File(self.path, 'r')
        store = self.data
        image, vis, target = (
            torch.tensor(store[key][item]) for key in ('image', 'vis', 'target')
        )
        return image, vis, target.long()