from __future__ import division import os import shutil import json import cv2 from PIL import Image import numpy as np from torch.utils.data import Dataset from utils.image import _palette class VOSTest(Dataset): def __init__(self, image_root, label_root, seq_name, images, labels, rgb=True, transform=None, single_obj=False, resolution=None): self.image_root = image_root self.label_root = label_root self.seq_name = seq_name self.images = images self.labels = labels self.obj_num = 1 self.num_frame = len(self.images) self.transform = transform self.rgb = rgb self.single_obj = single_obj self.resolution = resolution self.obj_nums = [] self.obj_indices = [] curr_objs = [0] for img_name in self.images: self.obj_nums.append(len(curr_objs) - 1) current_label_name = img_name.split('.')[0] + '.png' if current_label_name in self.labels: current_label = self.read_label(current_label_name) curr_obj = list(np.unique(current_label)) for obj_idx in curr_obj: if obj_idx not in curr_objs: curr_objs.append(obj_idx) self.obj_indices.append(curr_objs.copy()) self.obj_nums[0] = self.obj_nums[1] def __len__(self): return len(self.images) def read_image(self, idx): img_name = self.images[idx] img_path = os.path.join(self.image_root, self.seq_name, img_name) img = cv2.imread(img_path) img = np.array(img, dtype=np.float32) if self.rgb: img = img[:, :, [2, 1, 0]] return img def read_label(self, label_name, squeeze_idx=None): label_path = os.path.join(self.label_root, self.seq_name, label_name) label = Image.open(label_path) label = np.array(label, dtype=np.uint8) if self.single_obj: label = (label > 0).astype(np.uint8) elif squeeze_idx is not None: squeezed_label = label * 0 for idx in range(len(squeeze_idx)): obj_id = squeeze_idx[idx] if obj_id == 0: continue mask = label == obj_id squeezed_label += (mask * idx).astype(np.uint8) label = squeezed_label return label def __getitem__(self, idx): img_name = self.images[idx] current_img = self.read_image(idx) height, width, channels = current_img.shape if self.resolution is not None: width = int(np.ceil( float(width) * self.resolution / float(height))) height = int(self.resolution) current_label_name = img_name.split('.')[0] + '.png' obj_num = self.obj_nums[idx] obj_idx = self.obj_indices[idx] if current_label_name in self.labels: current_label = self.read_label(current_label_name, obj_idx) sample = { 'current_img': current_img, 'current_label': current_label } else: sample = {'current_img': current_img} sample['meta'] = { 'seq_name': self.seq_name, 'frame_num': self.num_frame, 'obj_num': obj_num, 'current_name': img_name, 'height': height, 'width': width, 'flip': False, 'obj_idx': obj_idx } if self.transform is not None: sample = self.transform(sample) return sample class YOUTUBEVOS_Test(object): def __init__(self, root='./datasets/YTB', year=2018, split='val', transform=None, rgb=True, result_root=None): if split == 'val': split = 'valid' root = os.path.join(root, str(year), split) self.db_root_dir = root self.result_root = result_root self.rgb = rgb self.transform = transform self.seq_list_file = os.path.join(self.db_root_dir, 'meta.json') self._check_preprocess() self.seqs = list(self.ann_f.keys()) self.image_root = os.path.join(root, 'JPEGImages') self.label_root = os.path.join(root, 'Annotations') def __len__(self): return len(self.seqs) def __getitem__(self, idx): seq_name = self.seqs[idx] data = self.ann_f[seq_name]['objects'] obj_names = list(data.keys()) images = [] labels = [] for obj_n in obj_names: images += map(lambda x: x + '.jpg', list(data[obj_n]["frames"])) labels.append(data[obj_n]["frames"][0] + '.png') images = np.sort(np.unique(images)) labels = np.sort(np.unique(labels)) try: if not os.path.isfile( os.path.join(self.result_root, seq_name, labels[0])): if not os.path.exists(os.path.join(self.result_root, seq_name)): os.makedirs(os.path.join(self.result_root, seq_name)) shutil.copy( os.path.join(self.label_root, seq_name, labels[0]), os.path.join(self.result_root, seq_name, labels[0])) except Exception as inst: print(inst) print('Failed to create a result folder for sequence {}.'.format( seq_name)) seq_dataset = VOSTest(self.image_root, self.label_root, seq_name, images, labels, transform=self.transform, rgb=self.rgb) return seq_dataset def _check_preprocess(self): _seq_list_file = self.seq_list_file if not os.path.isfile(_seq_list_file): print(_seq_list_file) return False else: self.ann_f = json.load(open(self.seq_list_file, 'r'))['videos'] return True class YOUTUBEVOS_DenseTest(object): def __init__(self, root='./datasets/YTB', year=2018, split='val', transform=None, rgb=True, result_root=None): if split == 'val': split = 'valid' root_sparse = os.path.join(root, str(year), split) root_dense = root_sparse + '_all_frames' self.db_root_dir = root_dense self.result_root = result_root self.rgb = rgb self.transform = transform self.seq_list_file = os.path.join(root_sparse, 'meta.json') self._check_preprocess() self.seqs = list(self.ann_f.keys()) self.image_root = os.path.join(root_dense, 'JPEGImages') self.label_root = os.path.join(root_sparse, 'Annotations') def __len__(self): return len(self.seqs) def __getitem__(self, idx): seq_name = self.seqs[idx] data = self.ann_f[seq_name]['objects'] obj_names = list(data.keys()) images_sparse = [] for obj_n in obj_names: images_sparse += map(lambda x: x + '.jpg', list(data[obj_n]["frames"])) images_sparse = np.sort(np.unique(images_sparse)) images = np.sort( list(os.listdir(os.path.join(self.image_root, seq_name)))) start_img = images_sparse[0] end_img = images_sparse[-1] for start_idx in range(len(images)): if start_img in images[start_idx]: break for end_idx in range(len(images))[::-1]: if end_img in images[end_idx]: break images = images[start_idx:(end_idx + 1)] labels = np.sort( list(os.listdir(os.path.join(self.label_root, seq_name)))) try: if not os.path.isfile( os.path.join(self.result_root, seq_name, labels[0])): if not os.path.exists(os.path.join(self.result_root, seq_name)): os.makedirs(os.path.join(self.result_root, seq_name)) shutil.copy( os.path.join(self.label_root, seq_name, labels[0]), os.path.join(self.result_root, seq_name, labels[0])) except Exception as inst: print(inst) print('Failed to create a result folder for sequence {}.'.format( seq_name)) seq_dataset = VOSTest(self.image_root, self.label_root, seq_name, images, labels, transform=self.transform, rgb=self.rgb) seq_dataset.images_sparse = images_sparse return seq_dataset def _check_preprocess(self): _seq_list_file = self.seq_list_file if not os.path.isfile(_seq_list_file): print(_seq_list_file) return False else: self.ann_f = json.load(open(self.seq_list_file, 'r'))['videos'] return True class DAVIS_Test(object): def __init__(self, split=['val'], root='./DAVIS', year=2017, transform=None, rgb=True, full_resolution=False, result_root=None): self.transform = transform self.rgb = rgb self.result_root = result_root if year == 2016: self.single_obj = True else: self.single_obj = False if full_resolution: resolution = 'Full-Resolution' else: resolution = '480p' self.image_root = os.path.join(root, 'JPEGImages', resolution) self.label_root = os.path.join(root, 'Annotations', resolution) seq_names = [] for spt in split: if spt == 'test': spt = 'test-dev' with open(os.path.join(root, 'ImageSets', str(year), spt + '.txt')) as f: seqs_tmp = f.readlines() seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp)) seq_names.extend(seqs_tmp) self.seqs = list(np.unique(seq_names)) def __len__(self): return len(self.seqs) def __getitem__(self, idx): seq_name = self.seqs[idx] images = list( np.sort(os.listdir(os.path.join(self.image_root, seq_name)))) labels = [images[0].replace('jpg', 'png')] if not os.path.isfile( os.path.join(self.result_root, seq_name, labels[0])): seq_result_folder = os.path.join(self.result_root, seq_name) try: if not os.path.exists(seq_result_folder): os.makedirs(seq_result_folder) except Exception as inst: print(inst) print( 'Failed to create a result folder for sequence {}.'.format( seq_name)) source_label_path = os.path.join(self.label_root, seq_name, labels[0]) result_label_path = os.path.join(self.result_root, seq_name, labels[0]) if self.single_obj: label = Image.open(source_label_path) label = np.array(label, dtype=np.uint8) label = (label > 0).astype(np.uint8) label = Image.fromarray(label).convert('P') label.putpalette(_palette) label.save(result_label_path) else: shutil.copy(source_label_path, result_label_path) seq_dataset = VOSTest(self.image_root, self.label_root, seq_name, images, labels, transform=self.transform, rgb=self.rgb, single_obj=self.single_obj, resolution=480) return seq_dataset class _EVAL_TEST(Dataset): def __init__(self, transform, seq_name): self.seq_name = seq_name self.num_frame = 10 self.transform = transform def __len__(self): return self.num_frame def __getitem__(self, idx): current_frame_obj_num = 2 height = 400 width = 400 img_name = 'test{}.jpg'.format(idx) current_img = np.zeros((height, width, 3)).astype(np.float32) if idx == 0: current_label = (current_frame_obj_num * np.ones( (height, width))).astype(np.uint8) sample = { 'current_img': current_img, 'current_label': current_label } else: sample = {'current_img': current_img} sample['meta'] = { 'seq_name': self.seq_name, 'frame_num': self.num_frame, 'obj_num': current_frame_obj_num, 'current_name': img_name, 'height': height, 'width': width, 'flip': False } if self.transform is not None: sample = self.transform(sample) return sample class EVAL_TEST(object): def __init__(self, transform=None, result_root=None): self.transform = transform self.result_root = result_root self.seqs = ['test1', 'test2', 'test3'] def __len__(self): return len(self.seqs) def __getitem__(self, idx): seq_name = self.seqs[idx] if not os.path.exists(os.path.join(self.result_root, seq_name)): os.makedirs(os.path.join(self.result_root, seq_name)) seq_dataset = _EVAL_TEST(self.transform, seq_name) return seq_dataset