# Partially taken from STM's dataloader
import os
from os import path

import torch
import torch.nn.functional as F
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np

from dataset.range_transform import im_normalization


class YouTubeVOSTestDataset(Dataset):
    def __init__(self, data_root, split):
        self.image_dir = path.join(data_root, 'vos', 'all_frames', split, 'JPEGImages')
        self.mask_dir = path.join(data_root, 'vos', split, 'Annotations')
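        # Expected on-disk layout, inferred from the path joins above (an
        # assumption about the local YouTubeVOS setup, adjust as needed):
        #   {data_root}/vos/all_frames/{split}/JPEGImages/{video}/{frame}.jpg
        #   {data_root}/vos/{split}/Annotations/{video}/{frame}.png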
        self.videos = []
        self.shape = {}
        self.frames = {}

        vid_list = sorted(os.listdir(self.image_dir))
        # Pre-read every video: cache its frame list and the spatial shape
        # of its first annotation mask
        for vid in vid_list:
            frames = sorted(os.listdir(os.path.join(self.image_dir, vid)))
            self.frames[vid] = frames
            self.videos.append(vid)
            first_mask = os.listdir(path.join(self.mask_dir, vid))[0]
            _mask = np.array(Image.open(path.join(self.mask_dir, vid, first_mask)).convert("P"))
            self.shape[vid] = np.shape(_mask)

        self.im_transform = transforms.Compose([
            transforms.ToTensor(),
            im_normalization,
        ])
    # From STM's code
    def To_onehot(self, mask, labels):
        # (H, W) label map -> (K, H, W) stack of binary masks, one per label
        M = np.zeros((len(labels), mask.shape[0], mask.shape[1]), dtype=np.uint8)
        for k, l in enumerate(labels):
            M[k] = (mask == l).astype(np.uint8)
        return M

    def All_to_onehot(self, masks, labels):
        # (T, H, W) label maps -> (K, T, H, W) one-hot masks over all frames
        Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
        for n in range(masks.shape[0]):
            Ms[:, n] = self.To_onehot(masks[n], labels)
        return Ms
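    # Shape example: for masks of shape (T, H, W) and labels [2, 5],
    # All_to_onehot returns a (2, T, H, W) uint8 array where Ms[0, n] is the
    # binary mask of object 2 in frame n and Ms[1, n] that of object 5.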
    def __getitem__(self, idx):
        video = self.videos[idx]
        info = {}
        info['name'] = video
        info['num_objects'] = 0
        info['frames'] = self.frames[video]
        info['size'] = self.shape[video]  # Real sizes
        info['gt_obj'] = {}  # Frames with labelled objects

        vid_im_path = path.join(self.image_dir, video)
        vid_gt_path = path.join(self.mask_dir, video)

        frames = self.frames[video]

        images = []
        masks = []
        for i, f in enumerate(frames):
            img = Image.open(path.join(vid_im_path, f)).convert('RGB')
            images.append(self.im_transform(img))

            mask_file = path.join(vid_gt_path, f.replace('.jpg', '.png'))
            if path.exists(mask_file):
                masks.append(np.array(Image.open(mask_file).convert('P'), dtype=np.uint8))
                this_labels = np.unique(masks[-1])
                this_labels = this_labels[this_labels != 0]
                info['gt_obj'][i] = this_labels
            else:
                # No mask file for this frame -> it contains no labelled objects;
                # uint8 keeps the stacked array consistent with the real masks
                masks.append(np.zeros(self.shape[video], dtype=np.uint8))
        images = torch.stack(images, 0)
        masks = np.stack(masks, 0)

        # Construct the forward and backward mapping tables for labels:
        # YouTubeVOS object IDs are not necessarily contiguous, but one-hot
        # encoding wants contiguous indices, so keep a backward table as well
        labels = np.unique(masks).astype(np.uint8)
        labels = labels[labels != 0]
        info['label_convert'] = {}
        info['label_backward'] = {}
        for new_idx, l in enumerate(labels, start=1):
            info['label_convert'][l] = new_idx
            info['label_backward'][new_idx] = l
        masks = torch.from_numpy(self.All_to_onehot(masks, labels)).float()
        masks = masks.unsqueeze(2)  # (K, T, H, W) -> (K, T, 1, H, W)

        # Resize so that the shorter side becomes 480 pixels
        h, w = masks.shape[-2:]
        if h > w:
            new_size = (h * 480 // w, 480)
        else:
            new_size = (480, w * 480 // h)
        images = F.interpolate(images, size=new_size, mode='bicubic', align_corners=False)
        masks = F.interpolate(masks, size=(1, *new_size), mode='nearest')
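        # Design note: nearest-neighbour interpolation keeps the one-hot masks
        # strictly binary after resizing, while bicubic gives smoother
        # resampled frames for the network input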
        info['labels'] = labels

        data = {
            'rgb': images,  # (T, 3, H, W) normalized frames
            'gt': masks,    # (K, T, 1, H, W) one-hot ground-truth masks
            'info': info,
        }

        return data

    def __len__(self):
        return len(self.videos)
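

# A minimal usage sketch (not part of the upstream file): it assumes a
# data_root of './data' laid out as noted in __init__ and a 'valid' split;
# both names are placeholders for your local setup.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dataset = YouTubeVOSTestDataset(data_root='./data', split='valid')
    # batch_size=1 because videos differ in length and resolution
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

    for data in loader:
        rgb = data['rgb']  # (1, T, 3, H, W) after default collation
        gt = data['gt']    # (1, K, T, 1, H, W)
        print(data['info']['name'], rgb.shape, gt.shape)
        break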