import json |
from torch.utils import data |
from torchvision.datasets import ImageFolder |
import torch |
import os |
from PIL import Image |
import numpy as np |
import argparse |
from tqdm import tqdm |
from munkres import Munkres |
import multiprocessing |
from multiprocessing import Process, Manager |
import collections |
import torchvision.transforms as transforms |
import torchvision.transforms.functional as TF |
import random |
import torchvision |
import cv2 |
import random |
# Seed torch's global RNG at import time so torch-driven randomness is repeatable.
torch.manual_seed(0)

# One dataset entry: the image's base file name (no extension) and its class tag.
SegItem = collections.namedtuple('SegItem', ['image_name', 'tag'])

# Per-channel normalisation with mean 0.5 and std 0.5 for each of three channels.
normalize = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
)

# Geometric augmentation pipeline: random resized crop to 224 plus random flip.
TRANSFORM_TRAIN = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
)

# Deterministic geometric pipeline: resize to 256, then centre-crop to 224.
TRANSFORM_EVAL = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
)

# Image-only pipeline: convert to a tensor, then normalise.
IMAGE_TRANSFORMS = transforms.Compose(
    [
        transforms.ToTensor(),
        normalize,
    ]
)

# Tags that are skipped when the list of segmentation classes is built.
MERGED_TAGS = {
    'n02808440',
    'n03642806',
    'n03773504',
    'n03832673',
    'n03887697',
    'n04008634',
    'n04355933',
    'n04356056',
    'n04493381',
    'n15075141',
}
class SegmentationDataset(ImageFolder):
    """Dataset of (binary segmentation mask, normalised image, class index) triples.

    Masks are read from ``<seg_path>/<tag>/<image_name>.png`` and the matching
    images from ``<imagenet_path>/<tag>/<image_name>.JPEG``. Tags listed in
    ``MERGED_TAGS`` are skipped.

    ``ImageFolder.__init__`` is not invoked; the sample index is built here
    from the segmentation directory instead.

    Args:
        seg_path: Root directory holding one sub-directory of ``.png`` masks
            per tag.
        imagenet_path: Root directory holding one sub-directory of ``.JPEG``
            images per tag.
        partition: One of ``LEGAL_PARTITIONS``. ``TRAIN_PARTITION`` keeps the
            first ``train_classes`` tags, ``VAL_PARTITION`` keeps the rest;
            any other legal value keeps every tag.
        num_samples: Maximum number of masks taken from each tag directory.
        train_classes: Number of tags assigned to the training partition.
        imagenet_classes_path: JSON file mapping each tag to a value that
            ``int()`` can convert to the class index.
        seed: When not ``None``, seeds Python's global ``random`` module and
            shuffles the tag list before it is split into partitions.
    """

    def __init__(self, seg_path, imagenet_path, partition=TRAIN_PARTITION, num_samples=2, train_classes=500,
                 imagenet_classes_path='imagenet_classes.json', seed=None):
        assert partition in LEGAL_PARTITIONS
        self._partition = partition
        self._seg_path = seg_path
        self._imagenet_path = imagenet_path
        with open(imagenet_classes_path, 'r') as f:
            self._imagenet_classes = json.load(f)

        # os.listdir returns entries in arbitrary order; sort so that the
        # train/val split (and the seeded shuffle) is reproducible.
        self._tag_list = sorted(
            tag for tag in os.listdir(self._seg_path) if tag not in MERGED_TAGS
        )

        # `is not None` so that seed=0 is honoured rather than silently ignored.
        if seed is not None:
            print(f'Shuffling training classes with seed {seed}')
            random.seed(seed)
            random.shuffle(self._tag_list)

        if partition == TRAIN_PARTITION:
            self._tag_list = self._tag_list[:train_classes]
        elif partition == VAL_PARTITION:
            self._tag_list = self._tag_list[train_classes:]

        for tag in self._tag_list:
            assert tag in self._imagenet_classes

        self._all_segementations = []
        sample_cap = max(num_samples, 0)  # a non-positive cap selects nothing
        for tag in self._tag_list:
            base_dir = os.path.join(self._seg_path, tag)
            # Sorted for the same reproducibility reason as the tag list.
            for seg in sorted(os.listdir(base_dir))[:sample_cap]:
                # splitext drops only the final extension, so names that
                # contain extra dots are kept intact.
                self._all_segementations.append(SegItem(os.path.splitext(seg)[0], tag))

    def __getitem__(self, item):
        """Load one sample.

        Args:
            item: Index into the list of collected segmentation entries.

        Returns:
            Tuple ``(seg_map, image_ten, class_name)``: a float32 mask tensor
            with a leading singleton dimension, the image after
            ``IMAGE_TRANSFORMS``, and the integer class index for the tag.

        Raises:
            AssertionError: If the mask does not contain exactly one label
                value other than 0 and 1000.
            Exception: If the partition is neither ``VAL_PARTITION`` nor
                ``TRAIN_PARTITION``.
        """
        seg_item = self._all_segementations[item]
        seg_path = os.path.join(self._seg_path, seg_item.tag, seg_item.image_name + ".png")
        image_path = os.path.join(self._imagenet_path, seg_item.tag, seg_item.image_name + ".JPEG")

        # Context managers release the underlying file handles promptly.
        with Image.open(seg_path) as seg_file:
            seg_pixels = np.array(seg_file)
        with Image.open(image_path) as image_file:
            image = image_file.convert('RGB')

        # Widen before decoding: the PNG data is typically 8-bit, and
        # multiplying an 8-bit array by 256 overflows (and raises under
        # NumPy 2 promotion rules).
        seg_pixels = seg_pixels.astype(np.int32)
        # Label id is stored as channel 1 * 256 + channel 0.
        seg_map = seg_pixels[:, :, 1] * 256 + seg_pixels[:, :, 0]
        assert len([cand for cand in np.unique(seg_map) if cand != 0 and cand != 1000]) == 1

        # Binarise: 0 and 1000 become background, the single remaining label
        # becomes foreground.
        seg_map[seg_map == 1000] = 0
        seg_map[seg_map != 0] = 1
        seg_map = torch.from_numpy(seg_map.astype(np.float32))
        seg_map = seg_map.reshape(1, seg_map.shape[-2], seg_map.shape[-1])

        # NOTE(review): Resize is applied to the float mask with its default
        # interpolation, so the resized mask may no longer be strictly 0/1 --
        # confirm whether nearest-neighbour interpolation is wanted here.
        if self._partition == VAL_PARTITION:
            image = TRANSFORM_EVAL(image)
            seg_map = TRANSFORM_EVAL(seg_map)
        elif self._partition == TRAIN_PARTITION:
            resize = transforms.Resize(size=(256, 256))
            image = resize(image)
            seg_map = resize(seg_map)

            # Draw the crop window once and apply it to both image and mask
            # so they stay aligned.
            i, j, h, w = transforms.RandomCrop.get_params(
                image, output_size=(224, 224))
            image = TF.crop(image, i, j, h, w)
            seg_map = TF.crop(seg_map, i, j, h, w)

            # Same idea for the flip: one coin toss shared by image and mask.
            if random.random() > 0.5:
                image = TF.hflip(image)
                seg_map = TF.hflip(seg_map)
        else:
            raise Exception(f"Unsupported partition type {self._partition}")

        image_ten = IMAGE_TRANSFORMS(image)
        class_name = int(self._imagenet_classes[seg_item.tag])
        return seg_map, image_ten, class_name

    def __len__(self):
        """Return the number of collected segmentation entries."""
        return len(self._all_segementations)