import itertools
import os
import random

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data.distributed
from PIL import Image, ImageOps
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from zoedepth.utils.config import change_dataset
from zoedepth.utils.easydict import EasyDict as edict

from .ddad import get_ddad_loader
from .diml_indoor_test import get_diml_indoor_loader
from .diml_outdoor_test import get_diml_outdoor_loader
from .diode import get_diode_loader
from .hypersim import get_hypersim_loader
from .ibims import get_ibims_loader
from .preprocess import CropParams, get_white_border, get_black_border
from .sun_rgbd_loader import get_sunrgbd_loader
from .vkitti import get_vkitti_loader
from .vkitti2 import get_vkitti2_loader


def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


def preprocessing_transforms(mode, **kwargs):
    return transforms.Compose([
        ToTensor(mode=mode, **kwargs)
    ])


class DepthDataLoader(object):
    def __init__(self, config, mode, device='cpu', transform=None, **kwargs):
        """
        Data loader for depth datasets

        Args:
            config (dict): Config dictionary. Refer to utils/config.py
            mode (str): "train", "online_eval" or "test"
            device (str, optional): Device to load the data on. Defaults to 'cpu'.
            transform (torchvision.transforms, optional): Transform to apply to the data. Defaults to None.
        """

        self.config = config

        # Evaluation-only datasets ship their own loaders; dispatch to them
        # and return early.
        if config.dataset == 'ibims':
            self.data = get_ibims_loader(config, batch_size=1, num_workers=1)
            return

        if config.dataset == 'sunrgbd':
            self.data = get_sunrgbd_loader(
                data_dir_root=config.sunrgbd_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_indoor':
            self.data = get_diml_indoor_loader(
                data_dir_root=config.diml_indoor_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'diml_outdoor':
            self.data = get_diml_outdoor_loader(
                data_dir_root=config.diml_outdoor_root, batch_size=1, num_workers=1)
            return

        if "diode" in config.dataset:
            self.data = get_diode_loader(
                config[config.dataset + "_root"], batch_size=1, num_workers=1)
            return

        if config.dataset == 'hypersim_test':
            self.data = get_hypersim_loader(
                config.hypersim_test_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti':
            self.data = get_vkitti_loader(
                config.vkitti_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'vkitti2':
            self.data = get_vkitti2_loader(
                config.vkitti2_root, batch_size=1, num_workers=1)
            return

        if config.dataset == 'ddad':
            self.data = get_ddad_loader(
                config.ddad_root, resize_shape=(352, 1216), batch_size=1, num_workers=1)
            return

        # Resize inputs only when explicitly requested via do_input_resize.
        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get("do_input_resize", False) else None

        if transform is None:
            transform = preprocessing_transforms(mode, size=img_size)

        if mode == 'train':
            self.training_samples = DataLoadPreprocess(
                config, mode, transform=transform, device=device)

            if config.distributed:
                self.train_sampler = torch.utils.data.distributed.DistributedSampler(
                    self.training_samples)
            else:
                self.train_sampler = None

            self.data = DataLoader(self.training_samples,
                                   batch_size=config.batch_size,
                                   shuffle=(self.train_sampler is None),
                                   num_workers=config.workers,
                                   pin_memory=True,
                                   # persistent workers require num_workers > 0
                                   persistent_workers=config.workers > 0,
                                   sampler=self.train_sampler)

        elif mode == 'online_eval':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            # Even in the distributed case, every process gets the whole eval
            # set (evaluation is reported from a single process), so no
            # DistributedSampler is used here.
            self.eval_sampler = None
            self.data = DataLoader(self.testing_samples, 1,
                                   shuffle=kwargs.get("shuffle_test", False),
                                   num_workers=1,
                                   pin_memory=False,
                                   sampler=self.eval_sampler)

        elif mode == 'test':
            self.testing_samples = DataLoadPreprocess(
                config, mode, transform=transform)
            self.data = DataLoader(self.testing_samples,
                                   1, shuffle=False, num_workers=1)

        else:
            raise ValueError(
                "mode should be one of 'train', 'test' or 'online_eval'. Got {}".format(mode))
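
# A minimal usage sketch (assuming a config built by zoedepth.utils.config,
# e.g. through its get_config helper, which is not shown in this file):
#
#   config = get_config("zoedepth", "train", "nyu")
#   loader = DepthDataLoader(config, "train").data
#   for batch in loader:
#       image, depth, mask = batch["image"], batch["depth"], batch["mask"]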


def repetitive_roundrobin(*iterables):
    """
    Cycles through iterables, but sample-wise: first the first sample of the
    first iterable, then the first sample of the second iterable, and so on;
    then the second sample of each iterable in turn, etc.

    If one iterable is shorter than the others, it is repeated (cycled) until
    all iterables are exhausted:
    repetitive_roundrobin('ABC', 'D', 'EF') --> A D E B D F C D E A D F
    """
    # Assumes every iterable yields at least one item; an empty one would make
    # the cycle below raise StopIteration again.
    iterables_ = [iter(it) for it in iterables]
    exhausted = [False] * len(iterables)
    while not all(exhausted):
        for i, it in enumerate(iterables_):
            try:
                yield next(it)
            except StopIteration:
                exhausted[i] = True
                # Restart the exhausted iterable and yield its first sample in
                # place of the one lost to StopIteration.
                iterables_[i] = itertools.cycle(iterables[i])
                yield next(iterables_[i])


class RepetitiveRoundRobinDataLoader(object):
    def __init__(self, *dataloaders):
        self.dataloaders = dataloaders

    def __iter__(self):
        return repetitive_roundrobin(*self.dataloaders)

    def __len__(self):
        # A loader only registers as exhausted one full round after its last
        # real sample, so every loader contributes (max_len + 1) samples.
        return len(self.dataloaders) * (max(len(dl) for dl in self.dataloaders) + 1)
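
# For example, with the iterables from the docstring above:
#   rr = RepetitiveRoundRobinDataLoader('ABC', 'D', 'EF')
#   ''.join(rr)  # -> 'ADEBDFCDEADF', and len(rr) == 3 * (3 + 1) == 12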


class MixedNYUKITTI(object):
    def __init__(self, config, mode, device='cpu', **kwargs):
        config = edict(config)
        # The two underlying loaders share the worker budget.
        config.workers = config.workers // 2
        self.config = config
        nyu_conf = change_dataset(edict(config), 'nyu')
        kitti_conf = change_dataset(edict(config), 'kitti')

        # The NYU config doubles as the "main" config of the mixed loader.
        self.config = config = nyu_conf
        img_size = self.config.get("img_size", None)
        img_size = img_size if self.config.get("do_input_resize", False) else None
        if mode == 'train':
            nyu_loader = DepthDataLoader(
                nyu_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            kitti_loader = DepthDataLoader(
                kitti_conf, mode, device=device, transform=preprocessing_transforms(mode, size=img_size)).data
            # Interleave NYU and KITTI batches sample-wise, cycling the
            # shorter loader until the longer one is exhausted.
            self.data = RepetitiveRoundRobinDataLoader(
                nyu_loader, kitti_loader)
        else:
            # Evaluation runs on NYU only.
            self.data = DepthDataLoader(nyu_conf, mode, device=device).data
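
# A minimal usage sketch (assuming a config for mixed NYU/KITTI training,
# e.g. the mixed-dataset setting built via zoedepth.utils.config; the exact
# builder is not shown in this file):
#
#   mixed = MixedNYUKITTI(config, 'train')
#   for batch in mixed.data:
#       ...  # batches alternate between NYU and KITTI samples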


def remove_leading_slash(s):
    if s[0] == '/' or s[0] == '\\':
        return s[1:]
    return s
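
# e.g. remove_leading_slash('/train/rgb_00001.jpg') -> 'train/rgb_00001.jpg'
# (illustrative path), so that os.path.join(data_root, ...) cannot be
# short-circuited by an absolute path in the split file.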


class CachedReader:
    def __init__(self, shared_dict=None):
        # Use the provided (possibly process-shared) dict as the cache;
        # fall back to a per-instance dict.
        if shared_dict:
            self._cache = shared_dict
        else:
            self._cache = {}

    def open(self, fpath):
        im = self._cache.get(fpath, None)
        if im is None:
            im = self._cache[fpath] = Image.open(fpath)
        return im
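
# A minimal usage sketch (assumes the caller creates the shared dict, e.g.
# with a multiprocessing manager, and passes it in as config.shared_dict):
#
#   import torch.multiprocessing as mp
#   shared_dict = mp.Manager().dict()
#   reader = CachedReader(shared_dict)
#   im = reader.open('path/to/rgb.jpg')  # cached across DataLoader workers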


class ImReader:
    def __init__(self):
        pass

    def open(self, fpath):
        return Image.open(fpath)


class DataLoadPreprocess(Dataset):
    def __init__(self, config, mode, transform=None, is_for_online_eval=False, **kwargs):
        self.config = config
        if mode == 'online_eval':
            with open(config.filenames_file_eval, 'r') as f:
                self.filenames = f.readlines()
        else:
            with open(config.filenames_file, 'r') as f:
                self.filenames = f.readlines()

        self.mode = mode
        self.transform = transform
        self.to_tensor = ToTensor(mode)
        self.is_for_online_eval = is_for_online_eval
        # Optionally cache decoded images in a dict shared across workers.
        if config.use_shared_dict:
            self.reader = CachedReader(config.shared_dict)
        else:
            self.reader = ImReader()

    def postprocess(self, sample):
        return sample

    def __getitem__(self, idx):
        sample_path = self.filenames[idx]
        focal = float(sample_path.split()[2])
        sample = {}

        if self.mode == 'train':
            # With use_right on KITTI, randomly swap in the right stereo view
            # (columns 3 and 4 of the split file).
            if self.config.dataset == 'kitti' and self.config.use_right and random.random() > 0.5:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[3]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[4]))
            else:
                image_path = os.path.join(
                    self.config.data_path, remove_leading_slash(sample_path.split()[0]))
                depth_path = os.path.join(
                    self.config.gt_path, remove_leading_slash(sample_path.split()[1]))

            image = self.reader.open(image_path)
            depth_gt = self.reader.open(depth_path)
            w, h = image.size

            if self.config.do_kb_crop:
                # KITTI benchmark ("KB") crop: a 352x1216 window anchored at
                # the bottom of the frame and horizontally centered.
                height = image.height
                width = image.width
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                depth_gt = depth_gt.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))
                image = image.crop(
                    (left_margin, top_margin, left_margin + 1216, top_margin + 352))

            if self.config.dataset == 'nyu' and self.config.avoid_boundary:
                # Crop away the white border of NYU frames, then pad back to
                # the original size: the image by reflection, the depth with
                # zeros (which the validity mask filters out later).
                crop_params = get_white_border(np.array(image, dtype=np.uint8))
                image = image.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))
                depth_gt = depth_gt.crop((crop_params.left, crop_params.top, crop_params.right, crop_params.bottom))

                image = np.array(image)
                image = np.pad(image, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right), (0, 0)), mode='reflect')
                image = Image.fromarray(image)

                depth_gt = np.array(depth_gt)
                depth_gt = np.pad(depth_gt, ((crop_params.top, h - crop_params.bottom), (crop_params.left, w - crop_params.right)), 'constant', constant_values=0)
                depth_gt = Image.fromarray(depth_gt)

            if self.config.do_random_rotate and self.config.aug:
                random_angle = (random.random() - 0.5) * 2 * self.config.degree
                image = self.rotate_image(image, random_angle)
                depth_gt = self.rotate_image(
                    depth_gt, random_angle, flag=Image.NEAREST)

            image = np.asarray(image, dtype=np.float32) / 255.0
            depth_gt = np.asarray(depth_gt, dtype=np.float32)
            depth_gt = np.expand_dims(depth_gt, axis=2)

            # Convert stored integer depth to metres: NYU stores millimetres,
            # KITTI-style datasets store depth scaled by 256.
            if self.config.dataset == 'nyu':
                depth_gt = depth_gt / 1000.0
            else:
                depth_gt = depth_gt / 256.0

            if self.config.aug and self.config.random_crop:
                image, depth_gt = self.random_crop(
                    image, depth_gt, self.config.input_height, self.config.input_width)

            if self.config.aug and self.config.random_translate:
                image, depth_gt = self.random_translate(image, depth_gt, self.config.max_translation)

            image, depth_gt = self.train_preprocess(image, depth_gt)
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample = {'image': image, 'depth': depth_gt, 'focal': focal,
                      'mask': mask, **sample}

        else:
            if self.mode == 'online_eval':
                data_path = self.config.data_path_eval
            else:
                data_path = self.config.data_path

            image_path = os.path.join(
                data_path, remove_leading_slash(sample_path.split()[0]))
            image = np.asarray(self.reader.open(image_path),
                               dtype=np.float32) / 255.0

            if self.mode == 'online_eval':
                gt_path = self.config.gt_path_eval
                depth_path = os.path.join(
                    gt_path, remove_leading_slash(sample_path.split()[1]))
                has_valid_depth = False
                try:
                    depth_gt = self.reader.open(depth_path)
                    has_valid_depth = True
                except IOError:
                    depth_gt = False

                if has_valid_depth:
                    depth_gt = np.asarray(depth_gt, dtype=np.float32)
                    depth_gt = np.expand_dims(depth_gt, axis=2)
                    if self.config.dataset == 'nyu':
                        depth_gt = depth_gt / 1000.0
                    else:
                        depth_gt = depth_gt / 256.0

                    mask = np.logical_and(
                        depth_gt >= self.config.min_depth, depth_gt <= self.config.max_depth).squeeze()[None, ...]
                else:
                    mask = False

            if self.config.do_kb_crop:
                height = image.shape[0]
                width = image.shape[1]
                top_margin = int(height - 352)
                left_margin = int((width - 1216) / 2)
                image = image[top_margin:top_margin + 352,
                              left_margin:left_margin + 1216, :]
                if self.mode == 'online_eval' and has_valid_depth:
                    depth_gt = depth_gt[top_margin:top_margin +
                                        352, left_margin:left_margin + 1216, :]

            if self.mode == 'online_eval':
                sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth,
                          'image_path': sample_path.split()[0], 'depth_path': sample_path.split()[1],
                          'mask': mask}
            else:
                sample = {'image': image, 'focal': focal}

        # Recompute the mask with strict bounds so that train and eval samples
        # agree on what counts as valid depth.
        if (self.mode == 'train') or ('has_valid_depth' in sample and sample['has_valid_depth']):
            mask = np.logical_and(depth_gt > self.config.min_depth,
                                  depth_gt < self.config.max_depth).squeeze()[None, ...]
            sample['mask'] = mask

        if self.transform:
            sample = self.transform(sample)

        sample = self.postprocess(sample)
        sample['dataset'] = self.config.dataset
        # Split lines in 'test' mode may lack a depth column; guard the lookup.
        parts = sample_path.split()
        sample = {**sample, 'image_path': parts[0],
                  'depth_path': parts[1] if len(parts) > 1 else None}

        return sample
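
    # Each line of the filenames file read by __getitem__ follows the
    # BTS-style "image_path depth_path focal" convention (with extra
    # right-view columns for KITTI when use_right is set). An illustrative
    # line (not taken from a shipped split file):
    #   kitchen_0028b/rgb_00045.jpg kitchen_0028b/sync_depth_00045.png 518.8579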

    def rotate_image(self, image, angle, flag=Image.BILINEAR):
        result = image.rotate(angle, resample=flag)
        return result

    def random_crop(self, img, depth, height, width):
        assert img.shape[0] >= height
        assert img.shape[1] >= width
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        # Sample a random top-left corner and cut the same window out of
        # image and depth.
        x = random.randint(0, img.shape[1] - width)
        y = random.randint(0, img.shape[0] - height)
        img = img[y:y + height, x:x + width, :]
        depth = depth[y:y + height, x:x + width, :]

        return img, depth

    def random_translate(self, img, depth, max_t=20):
        assert img.shape[0] == depth.shape[0]
        assert img.shape[1] == depth.shape[1]
        # Translate with probability translate_prob; otherwise return the
        # pair unchanged.
        p = self.config.translate_prob
        do_translate = random.random()
        if do_translate > p:
            return img, depth
        x = random.randint(-max_t, max_t)
        y = random.randint(-max_t, max_t)
        M = np.float32([[1, 0, x], [0, 1, y]])
        # Shift image and depth by the same (x, y) offset.
        img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
        depth = cv2.warpAffine(depth, M, (depth.shape[1], depth.shape[0]))
        # warpAffine drops the trailing axis of single-channel inputs;
        # restore it.
        depth = depth.squeeze()[..., None]

        return img, depth

    def train_preprocess(self, image, depth_gt):
        if self.config.aug:
            # Random horizontal flip of image and depth together.
            do_flip = random.random()
            if do_flip > 0.5:
                image = (image[:, ::-1, :]).copy()
                depth_gt = (depth_gt[:, ::-1, :]).copy()

            # Random photometric augmentation (image only).
            do_augment = random.random()
            if do_augment > 0.5:
                image = self.augment_image(image)

        return image, depth_gt

    def augment_image(self, image):
        # Gamma augmentation.
        gamma = random.uniform(0.9, 1.1)
        image_aug = image ** gamma

        # Brightness augmentation (wider range for NYU's indoor scenes).
        if self.config.dataset == 'nyu':
            brightness = random.uniform(0.75, 1.25)
        else:
            brightness = random.uniform(0.9, 1.1)
        image_aug = image_aug * brightness

        # Per-channel color augmentation.
        colors = np.random.uniform(0.9, 1.1, size=3)
        white = np.ones((image.shape[0], image.shape[1]))
        color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
        image_aug *= color_image
        image_aug = np.clip(image_aug, 0, 1)

        return image_aug

    def __len__(self):
        return len(self.filenames)


class ToTensor(object):
    def __init__(self, mode, do_normalize=False, size=None):
        self.mode = mode
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) if do_normalize else nn.Identity()
        self.size = size
        if size is not None:
            self.resize = transforms.Resize(size=size)
        else:
            self.resize = nn.Identity()

    def __call__(self, sample):
        image, focal = sample['image'], sample['focal']
        image = self.to_tensor(image)
        image = self.normalize(image)
        image = self.resize(image)

        if self.mode == 'test':
            return {'image': image, 'focal': focal}

        depth = sample['depth']
        if self.mode == 'train':
            depth = self.to_tensor(depth)
            return {**sample, 'image': image, 'depth': depth, 'focal': focal}
        else:
            # Depth stays at its native resolution for evaluation; the image
            # was already resized above.
            has_valid_depth = sample['has_valid_depth']
            return {**sample, 'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth,
                    'image_path': sample['image_path'], 'depth_path': sample['depth_path']}

    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError(
                'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
            # HWC -> CHW
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # PIL images: 32-bit and 16-bit integer modes map directly to arrays;
        # everything else goes through the raw byte buffer.
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))

        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        # HWC -> CHW
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img
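
# A minimal usage sketch (shapes and values are illustrative):
#
#   to_tensor = ToTensor(mode='train', size=(384, 512))
#   sample = {'image': np.zeros((480, 640, 3), dtype=np.float32),
#             'depth': np.zeros((480, 640, 1), dtype=np.float32),
#             'focal': 518.8579}
#   out = to_tensor(sample)  # out['image']: float32 CHW tensor, resized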