import glob import json import os import random import cv2 import numpy as np import torch import torch.nn.functional as F from PIL import Image from pycocotools.coco import COCO from model.segment_anything.utils.transforms import ResizeLongestSide from torchvision import transforms def init_mapillary(base_image_dir): mapillary_data_root = os.path.join(base_image_dir, "mapillary") with open(os.path.join(mapillary_data_root, "config_v2.0.json")) as f: mapillary_classes = json.load(f)["labels"] mapillary_classes = [x["readable"].lower() for x in mapillary_classes] mapillary_classes = np.array(mapillary_classes) mapillary_labels = sorted( glob.glob( os.path.join(mapillary_data_root, "training", "v2.0", "labels", "*.png") ) ) mapillary_images = [ x.replace(".png", ".jpg").replace("v2.0/labels", "images") for x in mapillary_labels ] print("mapillary: ", len(mapillary_images)) return mapillary_classes, mapillary_images, mapillary_labels def init_ade20k(base_image_dir): with open("utils/ade20k_classes.json", "r") as f: ade20k_classes = json.load(f) ade20k_classes = np.array(ade20k_classes) image_ids = sorted( os.listdir(os.path.join(base_image_dir, "ade20k/images", "training")) ) ade20k_image_ids = [] for x in image_ids: if x.endswith(".jpg"): ade20k_image_ids.append(x[:-4]) ade20k_images = [] for image_id in ade20k_image_ids: # self.descriptions: ade20k_images.append( os.path.join( base_image_dir, "ade20k", "images", "training", "{}.jpg".format(image_id), ) ) ade20k_labels = [ x.replace(".jpg", ".png").replace("images", "annotations") for x in ade20k_images ] print("ade20k: ", len(ade20k_images)) return ade20k_classes, ade20k_images, ade20k_labels def init_paco_lvis(base_image_dir): coco_api_paco_lvis = COCO( os.path.join( base_image_dir, "vlpart", "paco", "annotations", "paco_lvis_v1_train.json" ) ) all_classes = coco_api_paco_lvis.loadCats(coco_api_paco_lvis.getCatIds()) class_map_paco_lvis = {} for cat in all_classes: cat_split = cat["name"].strip().split(":") if len(cat_split) == 1: name = cat_split[0].split("_(")[0] else: assert len(cat_split) == 2 obj, part = cat_split obj = obj.split("_(")[0] part = part.split("_(")[0] name = (obj, part) class_map_paco_lvis[cat["id"]] = name img_ids = coco_api_paco_lvis.getImgIds() print("paco_lvis: ", len(img_ids)) return class_map_paco_lvis, img_ids, coco_api_paco_lvis def init_pascal_part(base_image_dir): coco_api_pascal_part = COCO( os.path.join(base_image_dir, "vlpart", "pascal_part", "train.json") ) all_classes = coco_api_pascal_part.loadCats(coco_api_pascal_part.getCatIds()) class_map_pascal_part = {} for cat in all_classes: cat_main, cat_part = cat["name"].strip().split(":") name = (cat_main, cat_part) class_map_pascal_part[cat["id"]] = name img_ids = coco_api_pascal_part.getImgIds() print("pascal_part: ", len(img_ids)) return class_map_pascal_part, img_ids, coco_api_pascal_part class SemSegDataset(torch.utils.data.Dataset): pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1) pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1) img_size = 1024 ignore_label = 255 def __init__( self, base_image_dir, tokenizer, samples_per_epoch=500 * 8 * 2 * 10, precision: str = "fp32", image_size: int = 224, num_classes_per_sample: int = 3, exclude_val=False, sem_seg_data="ade20k||pascal_part||mapillary", model_type="ori", transform=ResizeLongestSide(1024), ): self.model_type = model_type self.exclude_val = exclude_val self.samples_per_epoch = samples_per_epoch self.num_classes_per_sample = num_classes_per_sample self.base_image_dir = base_image_dir self.tokenizer = tokenizer self.precision = precision self.transform = transform self.image_preprocessor = transforms.Compose([ transforms.ToTensor(), transforms.Resize((image_size, image_size), interpolation=3), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) ]) self.data2list = {} self.data2classes = {} self.sem_seg_datas = sem_seg_data.split("||") for ds in self.sem_seg_datas: classes, images, labels = eval("init_{}".format(ds))(base_image_dir) self.data2list[ds] = (images, labels) self.data2classes[ds] = classes def __len__(self): return self.samples_per_epoch def preprocess(self, x: torch.Tensor) -> torch.Tensor: """Normalize pixel values and pad to a square input.""" if self.model_type=="hq": h, w = x.shape[-2:] padh = self.img_size - h padw = self.img_size - w x = F.pad(x, (0, padw, 0, padh), value=128) # Normalize colors x = (x - self.pixel_mean) / self.pixel_std if self.model_type=="effi": x = F.interpolate(x.unsqueeze(0), (self.img_size, self.img_size), mode="bilinear").squeeze(0) else: # Pad h, w = x.shape[-2:] padh = self.img_size - h padw = self.img_size - w x = F.pad(x, (0, padw, 0, padh)) return x def __getitem__(self, idx): ds = random.randint(0, len(self.sem_seg_datas) - 1) ds = self.sem_seg_datas[ds] if ds in ["pascal_part"]: class_map = self.data2classes[ds] img_ids, coco_api = self.data2list[ds] idx = random.randint(0, len(img_ids) - 1) img_id = img_ids[idx] image_info = coco_api.loadImgs([img_id])[0] file_name = image_info["file_name"] file_name = os.path.join( "VOCdevkit", "VOC2010", "JPEGImages", file_name ) image_path = os.path.join(self.base_image_dir, "vlpart", ds, file_name) image = cv2.imread(image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # preprocess image for evf image_evf = self.image_preprocessor(image) image = self.transform.apply_image(image) # preprocess image for sam resize = image.shape[:2] annIds = coco_api.getAnnIds(imgIds=image_info["id"]) anns = coco_api.loadAnns(annIds) if len(anns) == 0: return self.__getitem__(0) if len(anns) >= self.num_classes_per_sample: sampled_anns = np.random.choice( anns, size=self.num_classes_per_sample, replace=False ).tolist() else: sampled_anns = anns sampled_classes = [] for ann in sampled_anns: sampled_cls = class_map[ann["category_id"]] if isinstance(sampled_cls, tuple): obj, part = sampled_cls if random.random() < 0.5: name = obj + " " + part else: name = "the {} of the {}".format(part, obj) else: name = sampled_cls sampled_classes.append(name) elif ds in ["ade20k", "mapillary"]: image, labels = self.data2list[ds] idx = random.randint(0, len(image) - 1) image_path = image[idx] label_path = labels[idx] label = Image.open(label_path) label = np.array(label) if ds == "ade20k": label[label == 0] = 255 label -= 1 label[label == 254] = 255 img = cv2.imread(image_path) image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # preprocess image for evf image_evf = self.image_preprocessor(image) image = self.transform.apply_image(image) # preprocess image for sam resize = image.shape[:2] unique_label = np.unique(label).tolist() if 255 in unique_label: unique_label.remove(255) if len(unique_label) == 0: return self.__getitem__(0) classes = [self.data2classes[ds][class_id] for class_id in unique_label] if len(classes) >= self.num_classes_per_sample: sampled_classes = np.random.choice( classes, size=self.num_classes_per_sample, replace=False ).tolist() else: sampled_classes = classes class_ids = [] for sampled_cls in sampled_classes: assert len(sampled_cls.split("||")) == 1 if ds in ["paco_lvis", "pascal_part"]: continue class_id = self.data2classes[ds].tolist().index(sampled_cls) class_ids.append(class_id) image = self.preprocess(torch.from_numpy(image).permute(2, 0, 1).contiguous()) if ds in ["pascal_part"]: masks = [] for ann in sampled_anns: try: masks.append(coco_api.annToMask(ann)) except Exception as e: print(e) return self.__getitem__(0) masks = np.stack(masks, axis=0) masks = torch.from_numpy(masks) label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label else: label = torch.from_numpy(label).long() masks = [] for class_id in class_ids: masks.append(label == class_id) masks = torch.stack(masks, dim=0) # sampled_classes = ["all "+_ for _ in sampled_classes] return ( image_path, image, image_evf, masks, label, resize, sampled_classes, )