# Partially taken from STM's dataloader

import os
from os import path

import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import random
from collections import defaultdict

from dataset.range_transform import im_normalization, im_mean
from dataset.reseed import reseed


class FusionDataset(Dataset):
    def __init__(self, im_root, gt_root, fd_root):
        """
        fd_root: Root to fusion_data/davis or fusion_data/bl
        Lots of variables here! See the return dict (at the end) for some comments.
        """
        self.im_root = im_root
        self.gt_root = gt_root
        self.fd_root = fd_root

        self.videos = []
        self.frames = {}
        self.vid_to_instance = defaultdict(list)

        vid_list = sorted(os.listdir(self.im_root))
        for vid in vid_list:
            frames = sorted(os.listdir(os.path.join(self.im_root, vid)))
            self.frames[vid] = frames
            self.videos.append(vid)

        total_fuse_vid = 0
        fuse_list = sorted(os.listdir(self.fd_root))
        # Run level - different parameters
        for folder in fuse_list:
            folder_path = path.join(self.fd_root, folder)
            video_list = sorted(os.listdir(folder_path))
            # Video level - different videos
            for vid in video_list:
                video_path = path.join(self.fd_root, folder, vid)
                self.vid_to_instance[vid].append(video_path)
                total_fuse_vid += 1

        # Filter out videos with no fusion data
        self.videos = [v for v in self.videos if v in self.vid_to_instance]

        print('%d out of %d videos accepted.' % (len(self.videos), len(vid_list)))
        print('%d fusion videos accepted.' % total_fuse_vid)

        self.im_dual_transform = transforms.Compose([
            # transforms.RandomAffine(degrees=30, shear=10, fillcolor=im_mean, resample=Image.BILINEAR),
            transforms.RandomHorizontalFlip(),
            # transforms.RandomResizedCrop((384, 384), scale=(0.34,1.0), ratio=(0.9,1.1), interpolation=Image.BILINEAR),
            transforms.RandomCrop(384),
            transforms.ColorJitter(0.1, 0.03, 0.03, 0.01),
        ])

        self.gt_dual_transform = transforms.Compose([
            # transforms.RandomAffine(degrees=30, shear=10, fillcolor=0, resample=Image.NEAREST),
            transforms.RandomHorizontalFlip(),
            # transforms.RandomResizedCrop((384, 384), scale=(0.34,1.0), ratio=(0.9,1.1), interpolation=Image.NEAREST),
            transforms.RandomCrop(384),
        ])

        self.sg_dual_transform = transforms.Compose([
            # transforms.RandomAffine(degrees=30, shear=10, fillcolor=0, resample=Image.BILINEAR),
            transforms.RandomHorizontalFlip(),
            # transforms.RandomResizedCrop((384, 384), scale=(0.34,1.0), ratio=(0.9,1.1), interpolation=Image.BILINEAR),
            transforms.RandomCrop(384),
        ])

        # Final transform without randomness
        self.final_im_transform = transforms.Compose([
            im_normalization,
        ])

    def __getitem__(self, idx):
        info = {}
        info['frames'] = []  # Appended with actual frames

        # Try a few times
        max_trial = 20
        for trials in range(max_trial):
            if trials < 5:
                video = self.videos[idx % len(self.videos)]
            else:
                video = np.random.choice(self.videos)
            vid_im_path = path.join(self.im_root, video)
            vid_gt_path = path.join(self.gt_root, video)
            info['name'] = video

            frames = self.frames[video]
            sequence_seed = np.random.randint(2147483647)

            video_path = self.vid_to_instance[video][np.random.choice(range(len(self.vid_to_instance[video])))]

            # Randomly pick the reference frames and object
            all_ref = os.listdir(video_path)
            first_ref = np.random.choice(all_ref)
            tar_obj = np.random.choice(os.listdir(path.join(video_path, first_ref)))
            tar_frame = np.random.choice(os.listdir(path.join(video_path, first_ref, tar_obj)))
            tar_obj_int = int(tar_obj)
            tar_frame_int = int(tar_frame[:-4])

            # Pick the second reference frame
            src2_ref_options = []
            for r in all_ref:
                # No self-reference
                if r == first_ref:
                    continue
                # We need the second reference frame to be visible from the first
                if not path.exists(path.join(video_path, first_ref, tar_obj, r+'.png')):
                    continue
                # We need the target object to exist at the target frame
                if path.exists(path.join(video_path, r, tar_obj, tar_frame)):
                    src2_ref_options.append(r)
            if len(src2_ref_options) > 0:
                secon_ref = np.random.choice(src2_ref_options)
            else:
                continue

            # Pick another object that is valid in both reference frames
            sec_obj_options = [obj for obj in os.listdir(path.join(video_path, first_ref))
                               if path.exists(path.join(video_path, first_ref, obj, tar_frame))
                               and path.exists(path.join(video_path, secon_ref, obj, tar_frame))
                               and obj != tar_obj]
            if len(sec_obj_options) == 0:
                sec_obj = -1
            else:
                sec_obj = np.random.choice(sec_obj_options)
            sec_obj_int = int(sec_obj)

            # Compute the distance from each reference frame to the target frame,
            # normalized by the distance between the two references
            dist_1 = abs(int(first_ref)-tar_frame_int) / abs(int(first_ref)-int(secon_ref))
            dist_2 = abs(int(secon_ref)-tar_frame_int) / abs(int(first_ref)-int(secon_ref))

            png_name = '%05d' % tar_frame_int + '.png'
            jpg_name = '%05d' % tar_frame_int + '.jpg'
            src2_ref_png_name = '%05d' % int(secon_ref) + '.png'
            src2_ref_jpg_name = '%05d' % int(secon_ref) + '.jpg'

            src1_seg = Image.open(path.join(video_path, first_ref, tar_obj, png_name)).convert('L')
            src2_seg = Image.open(path.join(video_path, secon_ref, tar_obj, png_name)).convert('L')

            # Transform these first two; reseeding before each call makes both
            # masks receive the same random flip/crop
            reseed(sequence_seed)
            src1_seg = np.array(self.sg_dual_transform(src1_seg))[:, :, np.newaxis]
            reseed(sequence_seed)
            src2_seg = np.array(self.sg_dual_transform(src2_seg))[:, :, np.newaxis]

            # Pixels where the two propagated masks disagree by more than 10%
            diff = np.abs(src1_seg.astype(np.float32) - src2_seg.astype(np.float32)) > (255*0.1)
            diff = diff.astype(np.uint8)
            usable_i, usable_j = np.nonzero(diff[:, :, 0])

            # ... (truncated in the source: an `if trials ...` retry check on the
            # disagreement region, followed by the loading and transformation of
            # the target frame image `im`, the GT masks `gt_mask`/`gt_mask2`, the
            # second object's propagated masks `src1_seg2`/`src2_seg2`, the
            # reference-2 data `src2_ref_seg`/`src2_ref_mask`/`src2_ref_seg2`/
            # `src2_ref_mask2`/`src2_ref_im`, and the construction of `dist` and
            # `selector`) ...

            cls_gt = np.zeros((384, 384), dtype=np.int64)
            cls_gt[gt_mask[0] > 0.5] = 1
            cls_gt[gt_mask2[0] > 0.5] = 2

            data = {
                # The target frame is defined to be the frame that requires fusion
                'rgb': im,              # Target frame image
                'cls_gt': cls_gt,       # Target frame ground truth in int format

                # First object
                'gt': gt_mask,          # GT mask of object 1 at the target frame
                'seg1': src1_seg,       # Propagated mask from reference 1 of object 1 at the target frame
                'seg2': src2_seg,       # Propagated mask from reference 2 of object 1 at the target frame
                'src2_ref': src2_ref_seg,       # Propagated mask from reference 1 of object 1 at reference 2
                'src2_ref_gt': src2_ref_mask,   # GT mask of object 1 at reference 2

                # Second object
                'gt2': gt_mask2,        # GT mask of object 2 at the target frame
                'seg12': src1_seg2,     # ... of object 2 ...
                'seg22': src2_seg2,     # ... of object 2 ...
                'src2_ref2': src2_ref_seg2,     # ... of object 2 ...
                'src2_ref_gt2': src2_ref_mask2, # ... of object 2 ...

                'src2_ref_im': src2_ref_im,     # Image at reference 2
                'dist': dist,           # Built in the elided block (from dist_1/dist_2 above)
                'selector': selector,   # Built in the elided block; masks out the second object when sec_obj == -1
                'info': info,
            }

            return data

    def __len__(self):
        return len(self.videos) * 100
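

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of feeding
# FusionDataset to a DataLoader. The three root paths below are assumptions --
# they depend on where the DAVIS images/annotations live and on how the fusion
# data was generated for this project.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = FusionDataset(
        im_root='DAVIS/JPEGImages/480p',   # assumed: one sub-folder of JPEGs per video
        gt_root='DAVIS/Annotations/480p',  # assumed: matching palettized PNG masks
        fd_root='fusion_data/davis',       # assumed: run/video/reference/object/frame.png
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

    batch = next(iter(loader))
    print(batch['rgb'].shape)     # target-frame images
    print(batch['cls_gt'].shape)  # integer label map: 0=bg, 1=object 1, 2=object 2
    print(batch['selector'])      # flags whether the second object is valid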