Kevin L
Initial commit
5fd3fc6
# Partially taken from STM's dataloader
import os
from os import path
import torch
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import random
from dataset.range_transform import im_normalization
class YouTubeVOSTestDataset(Dataset):
def __init__(self, data_root, split):
self.image_dir = path.join(data_root, 'vos', 'all_frames', split, 'JPEGImages')
self.mask_dir = path.join(data_root, 'vos', split, 'Annotations')
self.videos = []
self.shape = {}
self.frames = {}
vid_list = sorted(os.listdir(self.image_dir))
# Pre-reading
for vid in vid_list:
frames = sorted(os.listdir(os.path.join(self.image_dir, vid)))
self.frames[vid] = frames
self.videos.append(vid)
first_mask = os.listdir(path.join(self.mask_dir, vid))[0]
_mask = np.array(Image.open(path.join(self.mask_dir, vid, first_mask)).convert("P"))
self.shape[vid] = np.shape(_mask)
self.im_transform = transforms.Compose([
transforms.ToTensor(),
im_normalization,
])
# From STM's code
def To_onehot(self, mask, labels):
M = np.zeros((len(labels), mask.shape[0], mask.shape[1]), dtype=np.uint8)
for k, l in enumerate(labels):
M[k] = (mask == l).astype(np.uint8)
return M
def All_to_onehot(self, masks, labels):
Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
for n in range(masks.shape[0]):
Ms[:,n] = self.To_onehot(masks[n], labels)
return Ms
def __getitem__(self, idx):
video = self.videos[idx]
info = {}
info['name'] = video
info['num_objects'] = 0
info['frames'] = self.frames[video]
info['size'] = self.shape[video] # Real sizes
info['gt_obj'] = {} # Frames with labelled objects
vid_im_path = path.join(self.image_dir, video)
vid_gt_path = path.join(self.mask_dir, video)
frames = self.frames[video]
images = []
masks = []
for i, f in enumerate(frames):
img = Image.open(path.join(vid_im_path, f)).convert('RGB')
images.append(self.im_transform(img))
mask_file = path.join(vid_gt_path, f.replace('.jpg','.png'))
if path.exists(mask_file):
masks.append(np.array(Image.open(mask_file).convert('P'), dtype=np.uint8))
this_labels = np.unique(masks[-1])
this_labels = this_labels[this_labels!=0]
info['gt_obj'][i] = this_labels
else:
# Mask not exists -> nothing in it
masks.append(np.zeros(self.shape[video]))
images = torch.stack(images, 0)
masks = np.stack(masks, 0)
# Construct the forward and backward mapping table for labels
labels = np.unique(masks).astype(np.uint8)
labels = labels[labels!=0]
info['label_convert'] = {}
info['label_backward'] = {}
idx = 1
for l in labels:
info['label_convert'][l] = idx
info['label_backward'][idx] = l
idx += 1
masks = torch.from_numpy(self.All_to_onehot(masks, labels)).float()
# images = images.unsqueeze(0)
masks = masks.unsqueeze(2)
# Resize to 480p
h, w = masks.shape[-2:]
if h > w:
new_size = (h*480//w, 480)
else:
new_size = (480, w*480//h)
images = F.interpolate(images, size=new_size, mode='bicubic', align_corners=False)
masks = F.interpolate(masks, size=(1, *new_size), mode='nearest')
info['labels'] = labels
data = {
'rgb': images,
'gt': masks,
'info': info,
}
return data
def __len__(self):
return len(self.videos)