from __future__ import division
import os
from glob import glob
import json
import random
import cv2
from PIL import Image

import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as TF

import dataloaders.image_transforms as IT

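# Disable OpenCV's internal multithreading so it does not spawn extra
# threads inside DataLoader worker processes.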
cv2.setNumThreads(0)


def _get_images(sample):
    return [sample['ref_img'], sample['prev_img']] + sample['curr_img']


def _get_labels(sample):
    return [sample['ref_label'], sample['prev_label']] + sample['curr_label']


def _merge_sample(sample1, sample2, min_obj_pixels=100, max_obj_n=10):
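    """Composite the foreground of sample2 onto sample1, frame by frame.

    sample2's object ids are shifted by max_obj_n so they cannot collide
    with sample1's ids. Objects whose area in the reference frame is at
    most min_obj_pixels are dropped, and the merged object count is capped
    at max_obj_n.
    """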
    sample1_images = _get_images(sample1)
    sample2_images = _get_images(sample2)

    sample1_labels = _get_labels(sample1)
    sample2_labels = _get_labels(sample2)

    obj_idx = torch.arange(0, max_obj_n * 2 + 1).view(max_obj_n * 2 + 1, 1, 1)
    selected_idx = None
    selected_obj = None

    all_img = []
    all_mask = []
    for idx, (s1_img, s2_img, s1_label, s2_label) in enumerate(
            zip(sample1_images, sample2_images, sample1_labels,
                sample2_labels)):
        s2_fg = (s2_label > 0).float()
        s2_bg = 1 - s2_fg
        merged_img = s1_img * s2_bg + s2_img * s2_fg
        merged_mask = s1_label * s2_bg.long() + (
            (s2_label + max_obj_n) * s2_fg.long())
        merged_mask = (merged_mask == obj_idx).float()
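        # The set of kept objects is decided once on the reference frame
        # (idx == 0) and reused for every later frame of the clip.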
        if idx == 0:
            after_merge_pixels = merged_mask.sum(dim=(1, 2), keepdim=True)
            selected_idx = after_merge_pixels > min_obj_pixels
            selected_idx[0] = True
            obj_num = selected_idx.sum().int().item() - 1
            selected_idx = selected_idx.expand(-1,
                                               s1_label.size()[1],
                                               s1_label.size()[2])
            if obj_num > max_obj_n:
                selected_obj = list(range(1, obj_num + 1))
                random.shuffle(selected_obj)
                selected_obj = [0] + selected_obj[:max_obj_n]

        merged_mask = merged_mask[selected_idx].view(obj_num + 1,
                                                     s1_label.size()[1],
                                                     s1_label.size()[2])
        if obj_num > max_obj_n:
            merged_mask = merged_mask[selected_obj]
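        # Slightly bias the background channel so argmax ties resolve to it.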
        merged_mask[0] += 0.1
        merged_mask = torch.argmax(merged_mask, dim=0, keepdim=True).long()

        all_img.append(merged_img)
        all_mask.append(merged_mask)

    sample = {
        'ref_img': all_img[0],
        'prev_img': all_img[1],
        'curr_img': all_img[2:],
        'ref_label': all_mask[0],
        'prev_label': all_mask[1],
        'curr_label': all_mask[2:]
    }
    sample['meta'] = sample1['meta']
    sample['meta']['obj_num'] = min(obj_num, max_obj_n)
    return sample


class StaticTrain(Dataset):
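    """Pre-training dataset that synthesizes a video clip from a single
    static image by repeatedly augmenting it (flip, affine, color jitter,
    resized crop)."""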
    def __init__(self,
                 root,
                 output_size,
                 seq_len=5,
                 max_obj_n=10,
                 dynamic_merge=True,
                 merge_prob=1.0,
                 aug_type='v1'):
        self.root = root
        self.clip_n = seq_len
        self.output_size = output_size
        self.max_obj_n = max_obj_n

        self.dynamic_merge = dynamic_merge
        self.merge_prob = merge_prob

        self.img_list = list()
        self.mask_list = list()

        dataset_list = list()
        dataset_names = [
            'COCO', 'ECSSD', 'MSRA10K', 'PASCAL-S', 'PASCALVOC2012'
        ]
        for dataset_name in dataset_names:
            img_dir = os.path.join(root, 'JPEGImages', dataset_name)
            mask_dir = os.path.join(root, 'Annotations', dataset_name)

            img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) + \
                sorted(glob(os.path.join(img_dir, '*.png')))
            mask_list = sorted(glob(os.path.join(mask_dir, '*.png')))

            if len(img_list) > 0:
                if len(img_list) == len(mask_list):
                    dataset_list.append(dataset_name)
                    self.img_list += img_list
                    self.mask_list += mask_list
                    print(f'\t{dataset_name}: {len(img_list)} imgs.')
                else:
                    print(
                        f'\tPreTrain dataset {dataset_name} has {len(img_list)} imgs but {len(mask_list)} annots. Mismatch, skip.'
                    )
            else:
                print(
                    f'\tPreTrain dataset {dataset_name} doesn\'t exist. Skip.')

        print(
            f'{len(self.img_list)} imgs are used for PreTrain. They are from {dataset_list}.'
        )

        self.aug_type = aug_type

        self.pre_random_horizontal_flip = IT.RandomHorizontalFlip(0.5)
        self.random_horizontal_flip = IT.RandomHorizontalFlip(0.3)

        if self.aug_type == 'v1':
            self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03)
        elif self.aug_type == 'v2':
            self.color_jitter = TF.RandomApply(
                [TF.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8)
            self.gray_scale = TF.RandomGrayscale(p=0.2)
            self.blur = TF.RandomApply([IT.GaussianBlur([.1, 2.])], p=0.3)
        else:
            raise NotImplementedError(f'Unknown aug_type: {aug_type}')

        self.random_affine = IT.RandomAffine(degrees=20,
                                             translate=(0.1, 0.1),
                                             scale=(0.9, 1.1),
                                             shear=10,
                                             resample=Image.BICUBIC,
                                             fillcolor=(124, 116, 104))
        base_ratio = float(output_size[1]) / output_size[0]
        self.random_resize_crop = IT.RandomResizedCrop(
            output_size, (0.8, 1),
            ratio=(base_ratio * 3. / 4., base_ratio * 4. / 3.),
            interpolation=Image.BICUBIC)
        self.to_tensor = TF.ToTensor()
        self.to_onehot = IT.ToOnehot(max_obj_n, shuffle=True)
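        # Normalize with the standard ImageNet mean and std.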
        self.normalize = TF.Normalize((0.485, 0.456, 0.406),
                                      (0.229, 0.224, 0.225))

    def __len__(self):
        return len(self.img_list)

    def load_image_in_PIL(self, path, mode='RGB'):
        img = Image.open(path)
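        # Force a full read so the underlying file handle can be released.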
        img.load()
        return img.convert(mode)

    def sample_sequence(self, idx):
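        """Build a pseudo clip of clip_n frames from one static image.
        Frame 0 gets only color jitter and a random resized crop; later
        frames additionally get random flips and affine warps."""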
        img_pil = self.load_image_in_PIL(self.img_list[idx], 'RGB')
        mask_pil = self.load_image_in_PIL(self.mask_list[idx], 'P')

        frames = []
        masks = []

        img_pil, mask_pil = self.pre_random_horizontal_flip(img_pil, mask_pil)

        for i in range(self.clip_n):
            img, mask = img_pil, mask_pil

            if i > 0:
                img, mask = self.random_horizontal_flip(img, mask)
                img, mask = self.random_affine(img, mask)

            img = self.color_jitter(img)

            img, mask = self.random_resize_crop(img, mask)

            if self.aug_type == 'v2':
                img = self.gray_scale(img)
                img = self.blur(img)

            mask = np.array(mask, np.uint8)

            if i == 0:
                mask, obj_list = self.to_onehot(mask)
                obj_num = len(obj_list)
            else:
                mask, _ = self.to_onehot(mask, obj_list)

            mask = torch.argmax(mask, dim=0, keepdim=True)

            frames.append(self.normalize(self.to_tensor(img)))
            masks.append(mask)

        sample = {
            'ref_img': frames[0],
            'prev_img': frames[1],
            'curr_img': frames[2:],
            'ref_label': masks[0],
            'prev_label': masks[1],
            'curr_label': masks[2:]
        }
        sample['meta'] = {
            'seq_name': self.img_list[idx],
            'frame_num': 1,
            'obj_num': obj_num
        }

        return sample

    def __getitem__(self, idx):
        sample1 = self.sample_sequence(idx)

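        # Merge in a second, different sample when the clip has no
        # foreground object, or randomly with probability merge_prob.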
        if self.dynamic_merge and (sample1['meta']['obj_num'] == 0
                                   or random.random() < self.merge_prob):
            rand_idx = np.random.randint(len(self.img_list))
            while rand_idx == idx:
                rand_idx = np.random.randint(len(self.img_list))

            sample2 = self.sample_sequence(rand_idx)

            sample = self.merge_sample(sample1, sample2)
        else:
            sample = sample1

        return sample

    def merge_sample(self, sample1, sample2, min_obj_pixels=100):
        return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)


class VOSTrain(Dataset):
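    """Generic video-object-segmentation training dataset.

    Samples a reference frame, a previous frame, and the remaining current
    frames from each sequence with randomized temporal gaps, retrying until
    every sampled object also appears in the reference frame.
    """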
    def __init__(self,
                 image_root,
                 label_root,
                 imglistdic,
                 transform=None,
                 rgb=True,
                 repeat_time=1,
                 rand_gap=3,
                 seq_len=5,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 merge_prob=0.3,
                 max_obj_n=10):
        self.image_root = image_root
        self.label_root = label_root
        self.rand_gap = rand_gap
        self.seq_len = seq_len
        self.rand_reverse = rand_reverse
        self.repeat_time = repeat_time
        self.transform = transform
        self.dynamic_merge = dynamic_merge
        self.merge_prob = merge_prob
        self.enable_prev_frame = enable_prev_frame
        self.max_obj_n = max_obj_n
        self.rgb = rgb
        self.imglistdic = imglistdic
        self.seqs = list(self.imglistdic.keys())
        print('Video Num: {} X {}'.format(len(self.seqs), self.repeat_time))

    def __len__(self):
        return int(len(self.seqs) * self.repeat_time)

    def reverse_seq(self, imagelist, lablist):
        if np.random.randint(2) == 1:
            imagelist = imagelist[::-1]
            lablist = lablist[::-1]
        return imagelist, lablist

    def get_ref_index(self,
                      seqname,
                      lablist,
                      objs,
                      min_fg_pixels=200,
                      max_try=5):
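        """Pick a reference frame whose objects all appear in objs and whose
        foreground exceeds min_fg_pixels; after max_try failed attempts the
        last sampled index is returned anyway."""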
        bad_indices = []
        for _ in range(max_try):
            ref_index = np.random.randint(len(lablist))
            if ref_index in bad_indices:
                continue
            ref_label = Image.open(
                os.path.join(self.label_root, seqname, lablist[ref_index]))
            ref_label = np.array(ref_label, dtype=np.uint8)
            ref_objs = list(np.unique(ref_label))
            is_consistent = True
            for obj in ref_objs:
                if obj == 0:
                    continue
                if obj not in objs:
                    is_consistent = False
            xs, ys = np.nonzero(ref_label)
            if len(xs) > min_fg_pixels and is_consistent:
                break
            bad_indices.append(ref_index)
        return ref_index

    def get_ref_index_v2(self,
                         seqname,
                         lablist,
                         min_fg_pixels=200,
                         max_try=20,
                         total_gap=0):
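        """Pick a reference frame with more than min_fg_pixels of foreground
        from the first len(lablist) - total_gap frames; after max_try failed
        attempts the last sampled index is returned anyway."""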
        search_range = len(lablist) - total_gap
        if search_range <= 1:
            return 0
        bad_indices = []
        for _ in range(max_try):
            ref_index = np.random.randint(search_range)
            if ref_index in bad_indices:
                continue
            ref_label = Image.open(
                os.path.join(self.label_root, seqname, lablist[ref_index]))
            ref_label = np.array(ref_label, dtype=np.uint8)
            xs, ys = np.nonzero(ref_label)
            if len(xs) > min_fg_pixels:
                break
            bad_indices.append(ref_index)
        return ref_index

    def get_curr_gaps(self, seq_len, max_gap=999, max_try=10):
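        """Draw seq_len random gaps in [1, rand_gap], retrying up to max_try
        times until their sum is at most max_gap; returns the gaps and their
        sum."""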
        for _ in range(max_try):
            curr_gaps = []
            total_gap = 0
            for _ in range(seq_len):
                gap = int(np.random.randint(self.rand_gap) + 1)
                total_gap += gap
                curr_gaps.append(gap)
            if total_gap <= max_gap:
                break
        return curr_gaps, total_gap

    def get_prev_index(self, lablist, total_gap):
        search_range = len(lablist) - total_gap
        if search_range > 1:
            prev_index = np.random.randint(search_range)
        else:
            prev_index = 0
        return prev_index

    def check_index(self, total_len, index, allow_reflect=True):
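        """Map an out-of-range index back into [0, total_len), reflecting it
        at the sequence boundaries when allow_reflect is True and clamping
        it otherwise."""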
        if total_len <= 1:
            return 0

        if index < 0:
            if allow_reflect:
                index = -index
                index = self.check_index(total_len, index, True)
            else:
                index = 0
        elif index >= total_len:
            if allow_reflect:
                index = 2 * (total_len - 1) - index
                index = self.check_index(total_len, index, True)
            else:
                index = total_len - 1

        return index

    def get_curr_indices(self, lablist, prev_index, gaps):
        total_len = len(lablist)
        curr_indices = []
        now_index = prev_index
        for gap in gaps:
            now_index += gap
            curr_indices.append(self.check_index(total_len, now_index))
        return curr_indices

    def get_image_label(self, seqname, imagelist, lablist, index):
        image = cv2.imread(
            os.path.join(self.image_root, seqname, imagelist[index]))
        image = np.array(image, dtype=np.float32)

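        # cv2.imread returns BGR; reorder the channels to RGB if requested.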
        if self.rgb:
            image = image[:, :, [2, 1, 0]]

        label = Image.open(
            os.path.join(self.label_root, seqname, lablist[index]))
        label = np.array(label, dtype=np.uint8)

        return image, label

    def sample_sequence(self, idx):
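        """Sample ref/prev/curr frames from one sequence, retrying up to
        max_try times until every sampled object also appears in the
        reference frame."""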
        idx = idx % len(self.seqs)
        seqname = self.seqs[idx]
        imagelist, lablist = self.imglistdic[seqname]
        frame_num = len(imagelist)
        if self.rand_reverse:
            imagelist, lablist = self.reverse_seq(imagelist, lablist)

        is_consistent = False
        max_try = 5
        try_step = 0
        while is_consistent is False and try_step < max_try:
            try_step += 1

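            # Randomly sample the temporal gaps between consecutive frames.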
            curr_gaps, total_gap = self.get_curr_gaps(self.seq_len - 1)

            if self.enable_prev_frame:
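                # Sample the previous frame first, then the current frames
                # after it; the reference frame is drawn last, outside the
                # [prev, last current] window when possible.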
                prev_index = self.get_prev_index(lablist, total_gap)
                prev_image, prev_label = self.get_image_label(
                    seqname, imagelist, lablist, prev_index)
                prev_objs = list(np.unique(prev_label))

                curr_indices = self.get_curr_indices(lablist, prev_index,
                                                     curr_gaps)
                curr_images, curr_labels, curr_objs = [], [], []
                for curr_index in curr_indices:
                    curr_image, curr_label = self.get_image_label(
                        seqname, imagelist, lablist, curr_index)
                    c_objs = list(np.unique(curr_label))
                    curr_images.append(curr_image)
                    curr_labels.append(curr_label)
                    curr_objs.extend(c_objs)

                objs = list(np.unique(prev_objs + curr_objs))

                start_index = prev_index
                end_index = max(curr_indices)

                _try_step = 0
                ref_index = self.get_ref_index_v2(seqname, lablist)
                while (ref_index > start_index and ref_index <= end_index
                       and _try_step < max_try):
                    _try_step += 1
                    ref_index = self.get_ref_index_v2(seqname, lablist)
                ref_image, ref_label = self.get_image_label(
                    seqname, imagelist, lablist, ref_index)
                ref_objs = list(np.unique(ref_label))
            else:
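                # Sample the reference frame first; the first current frame
                # becomes the previous frame, the rest stay current frames.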
                ref_index = self.get_ref_index_v2(seqname, lablist)

                ref_image, ref_label = self.get_image_label(
                    seqname, imagelist, lablist, ref_index)
                ref_objs = list(np.unique(ref_label))

                curr_indices = self.get_curr_indices(lablist, ref_index,
                                                     curr_gaps)
                curr_images, curr_labels, curr_objs = [], [], []
                for curr_index in curr_indices:
                    curr_image, curr_label = self.get_image_label(
                        seqname, imagelist, lablist, curr_index)
                    c_objs = list(np.unique(curr_label))
                    curr_images.append(curr_image)
                    curr_labels.append(curr_label)
                    curr_objs.extend(c_objs)

                objs = list(np.unique(curr_objs))
                prev_image, prev_label = curr_images[0], curr_labels[0]
                curr_images, curr_labels = curr_images[1:], curr_labels[1:]

            is_consistent = True
            for obj in objs:
                if obj == 0:
                    continue
                if obj not in ref_objs:
                    is_consistent = False
                    break

        obj_num = list(np.sort(ref_objs))[-1]

        sample = {
            'ref_img': ref_image,
            'prev_img': prev_image,
            'curr_img': curr_images,
            'ref_label': ref_label,
            'prev_label': prev_label,
            'curr_label': curr_labels
        }
        sample['meta'] = {
            'seq_name': seqname,
            'frame_num': frame_num,
            'obj_num': obj_num
        }

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def __getitem__(self, idx):
        sample1 = self.sample_sequence(idx)

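        # As in StaticTrain: merge in another sequence when no object is
        # present, or randomly with probability merge_prob.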
        if self.dynamic_merge and (sample1['meta']['obj_num'] == 0
                                   or random.random() < self.merge_prob):
            rand_idx = np.random.randint(len(self.seqs))
            while rand_idx == (idx % len(self.seqs)):
                rand_idx = np.random.randint(len(self.seqs))

            sample2 = self.sample_sequence(rand_idx)

            sample = self.merge_sample(sample1, sample2)
        else:
            sample = sample1

        return sample

    def merge_sample(self, sample1, sample2, min_obj_pixels=100):
        return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)


class DAVIS2017_Train(VOSTrain):
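    """VOSTrain on DAVIS 2017: prefers the Full-Resolution split and falls
    back to 480p when it is not on disk."""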
    def __init__(self,
                 split=['train'],
                 root='./DAVIS',
                 transform=None,
                 rgb=True,
                 repeat_time=1,
                 full_resolution=True,
                 year=2017,
                 rand_gap=3,
                 seq_len=5,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 max_obj_n=10,
                 merge_prob=0.3):
        if full_resolution:
            resolution = 'Full-Resolution'
            if not os.path.exists(os.path.join(root, 'JPEGImages',
                                               resolution)):
                print('No Full-Resolution DAVIS found, using 480p instead.')
                resolution = '480p'
        else:
            resolution = '480p'
        image_root = os.path.join(root, 'JPEGImages', resolution)
        label_root = os.path.join(root, 'Annotations', resolution)
        seq_names = []
        for spt in split:
            with open(os.path.join(root, 'ImageSets', str(year),
                                   spt + '.txt')) as f:
                seqs_tmp = f.readlines()
            seqs_tmp = list(map(lambda elem: elem.strip(), seqs_tmp))
            seq_names.extend(seqs_tmp)
        imglistdic = {}
        for seq_name in seq_names:
            images = list(
                np.sort(os.listdir(os.path.join(image_root, seq_name))))
            labels = list(
                np.sort(os.listdir(os.path.join(label_root, seq_name))))
            imglistdic[seq_name] = (images, labels)

        super(DAVIS2017_Train, self).__init__(image_root,
                                              label_root,
                                              imglistdic,
                                              transform,
                                              rgb,
                                              repeat_time,
                                              rand_gap,
                                              seq_len,
                                              rand_reverse,
                                              dynamic_merge,
                                              enable_prev_frame,
                                              merge_prob=merge_prob,
                                              max_obj_n=max_obj_n)


class YOUTUBEVOS_Train(VOSTrain):
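    """VOSTrain on YouTube-VOS: enumerates the annotated frames per object
    from meta.json, skipping objects and videos with fewer than two
    frames."""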
    def __init__(self,
                 root='./datasets/YTB',
                 year=2019,
                 transform=None,
                 rgb=True,
                 rand_gap=3,
                 seq_len=3,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 max_obj_n=10,
                 merge_prob=0.3):
        root = os.path.join(root, str(year), 'train')
        image_root = os.path.join(root, 'JPEGImages')
        label_root = os.path.join(root, 'Annotations')
        self.seq_list_file = os.path.join(root, 'meta.json')
        self._check_preprocess()
        seq_names = list(self.ann_f.keys())

        imglistdic = {}
        for seq_name in seq_names:
            data = self.ann_f[seq_name]['objects']
            obj_names = list(data.keys())
            images = []
            labels = []
            for obj_n in obj_names:
                if len(data[obj_n]["frames"]) < 2:
                    print("Short object: " + seq_name + '-' + obj_n)
                    continue
                images += list(
                    map(lambda x: x + '.jpg', list(data[obj_n]["frames"])))
                labels += list(
                    map(lambda x: x + '.png', list(data[obj_n]["frames"])))
            images = np.sort(np.unique(images))
            labels = np.sort(np.unique(labels))
            if len(images) < 2:
                print("Short video: " + seq_name)
                continue
            imglistdic[seq_name] = (images, labels)

        super(YOUTUBEVOS_Train, self).__init__(image_root,
                                               label_root,
                                               imglistdic,
                                               transform,
                                               rgb,
                                               1,
                                               rand_gap,
                                               seq_len,
                                               rand_reverse,
                                               dynamic_merge,
                                               enable_prev_frame,
                                               merge_prob=merge_prob,
                                               max_obj_n=max_obj_n)

    def _check_preprocess(self):
        if not os.path.isfile(self.seq_list_file):
            print('No such file: {}.'.format(self.seq_list_file))
            return False
        with open(self.seq_list_file, 'r') as f:
            self.ann_f = json.load(f)['videos']
        return True


class TEST(Dataset):
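    """Dummy dataset yielding constant frames and masks, e.g. for debugging
    the training pipeline without real data."""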
    def __init__(self, seq_len=3, obj_num=3, transform=None):
        self.seq_len = seq_len
        self.obj_num = obj_num
        self.transform = transform

    def __len__(self):
        return 3000

    def __getitem__(self, idx):
        img = np.zeros((800, 800, 3), dtype=np.float32)
        label = np.ones((800, 800), dtype=np.uint8)
        sample = {
            'ref_img': img,
            'prev_img': img,
            'curr_img': [img] * (self.seq_len - 2),
            'ref_label': label,
            'prev_label': label,
            'curr_label': [label] * (self.seq_len - 2)
        }
        sample['meta'] = {
            'seq_name': 'test',
            'frame_num': 100,
            'obj_num': self.obj_num
        }

        if self.transform is not None:
            sample = self.transform(sample)
        return sample