import os import os.path import numpy as np import torch import torch.utils.data as data from PIL import Image class ParsingGenerationDeepFashionAttrSegmDataset(data.Dataset): def __init__(self, segm_dir, pose_dir, ann_file, downsample_factor=2): self._densepose_path = pose_dir self._segm_path = segm_dir self._image_fnames = [] self.attrs = [] self.downsample_factor = downsample_factor # training, ground-truth available assert os.path.exists(ann_file) for row in open(os.path.join(ann_file), 'r'): annotations = row.split() self._image_fnames.append(annotations[0]) self.attrs.append([int(i) for i in annotations[1:]]) def _open_file(self, path_prefix, fname): return open(os.path.join(path_prefix, fname), 'rb') def _load_densepose(self, raw_idx): fname = self._image_fnames[raw_idx] fname = f'{fname[:-4]}_densepose.png' with self._open_file(self._densepose_path, fname) as f: densepose = Image.open(f) if self.downsample_factor != 1: width, height = densepose.size width = width // self.downsample_factor height = height // self.downsample_factor densepose = densepose.resize( size=(width, height), resample=Image.NEAREST) # channel-wise IUV order, [3, H, W] densepose = np.array(densepose)[:, :, 2:].transpose(2, 0, 1) return densepose.astype(np.float32) def _load_segm(self, raw_idx): fname = self._image_fnames[raw_idx] fname = f'{fname[:-4]}_segm.png' with self._open_file(self._segm_path, fname) as f: segm = Image.open(f) if self.downsample_factor != 1: width, height = segm.size width = width // self.downsample_factor height = height // self.downsample_factor segm = segm.resize( size=(width, height), resample=Image.NEAREST) segm = np.array(segm) return segm.astype(np.float32) def __getitem__(self, index): pose = self._load_densepose(index) segm = self._load_segm(index) attr = self.attrs[index] pose = torch.from_numpy(pose) segm = torch.LongTensor(segm) attr = torch.LongTensor(attr) pose = pose / 12. - 1 return_dict = { 'densepose': pose, 'segm': segm, 'attr': attr, 'img_name': self._image_fnames[index] } return return_dict def __len__(self): return len(self._image_fnames)