from __future__ import absolute_import, division, print_function

import copy
import io
import os
import random
from collections import Counter
from typing import Tuple

import numpy as np
import skimage.transform
import torch
import torch.utils.data as data
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

import utils


class ImgDset(Dataset):
    """Custom dataset that loads high-resolution images and derives the matching low-resolution inputs.

    Args:
        dataroot (str): Directory containing the training images
        image_size (int): High-resolution (HR) image size
        upscale_factor (int): Super-resolution upscale factor
        mode (str): Dataset split; the "train" split applies data augmentation,
            the validation split does not
    """

    def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
        super(ImgDset, self).__init__()
        self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]
        if mode == "train":
            self.hr_transforms = transforms.Compose([
                transforms.RandomCrop(image_size),
                transforms.RandomRotation(90),
                transforms.RandomHorizontalFlip(0.5),
            ])
        else:
            self.hr_transforms = transforms.Resize(image_size)
        # image_size is an int, so the LR size is simply the HR size divided by the upscale factor
        self.lr_transforms = transforms.Resize(image_size // upscale_factor,
                                               interpolation=IMode.BICUBIC,
                                               antialias=True)

    def __getitem__(self, batch_index: int) -> Tuple[Tensor, Tensor]:
        # Read one image from disk
        image = Image.open(self.filenames[batch_index])
        # Build the HR image, then derive the LR image from it
        hr_image = self.hr_transforms(image)
        lr_image = self.lr_transforms(hr_image)
        # Convert image data into PyTorch tensors.
        # Note: the range of input and output is [0, 1]
        lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
        hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)
        return lr_tensor, hr_tensor

    def __len__(self) -> int:
        return len(self.filenames)
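
# Minimal usage sketch (illustrative only; the dataroot path and size values below are
# assumptions, not part of this module):
#   train_set = ImgDset(dataroot="data/DIV2K/train", image_size=96, upscale_factor=4, mode="train")
#   loader = data.DataLoader(train_set, batch_size=16, shuffle=True)
#   lr, hr = next(iter(loader))   # lr is 4x smaller than hr; both in [0, 1]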


class PairedImages_w_nameList(Dataset):
    """Paired image dataset driven by two file lists.

    Can act as supervised or unsupervised depending on how the two file lists are built.
    """

    def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        impath1 = self.flist1[index]
        img1 = Image.open(impath1).convert('RGB')
        impath2 = self.flist2[index]
        img2 = Image.open(impath2).convert('RGB')
        # Convert to tensors in [0, 1] before applying the optional transforms
        img1 = utils.image2tensor(img1, range_norm=False, half=False)
        img2 = utils.image2tensor(img2, range_norm=False, half=False)
        if self.transform1 is not None:
            img1 = self.transform1(img1)
        if self.transform2 is not None:
            img2 = self.transform2(img2)
        return img1, img2

    def __len__(self):
        return len(self.flist1)
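
# Usage sketch (hypothetical directories; transform1/transform2 are optional callables on tensors):
#   import glob
#   blur_list  = sorted(glob.glob("data/train/blur/*.png"))
#   sharp_list = sorted(glob.glob("data/train/sharp/*.png"))
#   pairs = PairedImages_w_nameList(blur_list, sharp_list)
#   blur_t, sharp_t = pairs[0]   # both tensors in [0, 1]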


class PairedImages_w_nameList_npy(Dataset):
    """Same as PairedImages_w_nameList, but loads .npy arrays instead of image files.

    Can act as supervised or unsupervised depending on how the two file lists are built.
    """

    def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        impath1 = self.flist1[index]
        img1 = np.load(impath1)
        impath2 = self.flist2[index]
        img2 = np.load(impath2)
        if self.transform1 is not None:
            img1 = self.transform1(img1)
        if self.transform2 is not None:
            img2 = self.transform2(img2)
        return img1, img2

    def __len__(self):
        return len(self.flist1)


# def call_paired():
#     import glob
#     root1 = './GOPRO_3840FPS_AVG_3-21/train/blur/'
#     root2 = './GOPRO_3840FPS_AVG_3-21/train/sharp/'
#     flist1 = glob.glob(root1 + '/*/*.png')
#     flist2 = glob.glob(root2 + '/*/*.png')
#     # the constructor takes the two file lists (plus optional transforms)
#     dset = PairedImages_w_nameList(flist1, flist2)


#### KITTI depth

def load_velodyne_points(filename):
    """Load 3D point cloud from KITTI file format
    (adapted from https://github.com/hunse/kitti)
    """
    points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
    points[:, 3] = 1.0  # homogeneous coordinates
    return points


def read_calib_file(path):
    """Read KITTI calibration file
    (from https://github.com/hunse/kitti)
    """
    float_chars = set("0123456789.e+- ")
    data = {}
    with open(path, 'r') as f:
        for line in f.readlines():
            key, value = line.split(':', 1)
            value = value.strip()
            data[key] = value
            if float_chars.issuperset(value):
                # try to cast to a float array
                try:
                    data[key] = np.array(list(map(float, value.split(' '))))
                except ValueError:
                    # casting error: data[key] already equals value, so pass
                    pass
    return data
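
# Sketch of the returned mapping (keys are illustrative; exact contents depend on the calibration file):
#   calib = read_calib_file('calib_cam_to_cam.txt')
#   calib['P_rect_02']   # np.ndarray of 12 floats (flattened 3x4 projection matrix)
#   calib['S_rect_02']   # np.ndarray [width, height]
#   non-numeric entries (e.g. 'calib_time') stay as plain strings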


def sub2ind(matrixSize, rowSub, colSub):
    """Convert row, col matrix subscripts to linear indices
    """
    m, n = matrixSize
    return rowSub * (n - 1) + colSub - 1


def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
    """Generate a depth map from velodyne data
    """
    # load calibration files
    cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
    velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
    velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
    velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))

    # get image shape
    im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)

    # compute projection matrix velodyne -> image plane
    R_cam2rect = np.eye(4)
    R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
    P_rect = cam2cam['P_rect_0' + str(cam)].reshape(3, 4)
    P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)

    # load velodyne points and remove all behind image plane (approximation)
    # each row of the velodyne data is forward, left, up, reflectance
    velo = load_velodyne_points(velo_filename)
    velo = velo[velo[:, 0] >= 0, :]

    # project the points to the camera
    velo_pts_im = np.dot(P_velo2im, velo.T).T
    velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]

    if vel_depth:
        velo_pts_im[:, 2] = velo[:, 0]

    # check if in bounds
    # use minus 1 to get the exact same value as the KITTI matlab code
    velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
    velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
    val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
    val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
    velo_pts_im = velo_pts_im[val_inds, :]

    # project to image
    depth = np.zeros(im_shape[:2])
    depth[velo_pts_im[:, 1].astype(int), velo_pts_im[:, 0].astype(int)] = velo_pts_im[:, 2]

    # find the duplicate points and choose the closest depth
    inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
    dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
    for dd in dupe_inds:
        pts = np.where(inds == dd)[0]
        x_loc = int(velo_pts_im[pts[0], 0])
        y_loc = int(velo_pts_im[pts[0], 1])
        depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
    depth[depth < 0] = 0

    return depth
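
# Usage sketch (hypothetical paths; the directory layout follows the raw KITTI release,
# matching the velodyne path format used in KITTIDataset.check_depth below):
#   calib_dir = 'kitti_data/2011_09_26'
#   velo_file = 'kitti_data/2011_09_26/2011_09_26_drive_0001_sync/velodyne_points/data/0000000000.bin'
#   depth = generate_depth_map(calib_dir, velo_file, cam=2)   # HxW array, 0 where there is no LiDAR return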


def pil_loader(path):
    # open path as a file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


class MonoDataset(data.Dataset):
    """Superclass for monocular dataloaders

    Args:
        data_path: root directory of the dataset
        filenames: list of "folder frame_index side" strings, one per sample
        height: target image height
        width: target image width
        frame_idxs: temporal frame offsets to load (e.g. [0, -1, 1]) and/or "s" for the stereo pair
        num_scales: number of scales in the image pyramid
        is_train: whether to apply training-time augmentation
        img_ext: image file extension
    """

    def __init__(self,
                 data_path,
                 filenames,
                 height,
                 width,
                 frame_idxs,
                 num_scales,
                 is_train=False,
                 img_ext='.jpg'):
        super(MonoDataset, self).__init__()

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales
        self.interp = Image.ANTIALIAS  # alias of Image.LANCZOS; removed in Pillow >= 10

        self.frame_idxs = frame_idxs

        self.is_train = is_train
        self.img_ext = img_ext

        self.loader = pil_loader
        self.to_tensor = transforms.ToTensor()

        # We need to specify augmentations differently in newer versions of torchvision.
        # We first try the newer tuple version; if this fails we fall back to scalars
        try:
            self.brightness = (0.8, 1.2)
            self.contrast = (0.8, 1.2)
            self.saturation = (0.8, 1.2)
            self.hue = (-0.1, 0.1)
            transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        except TypeError:
            self.brightness = 0.2
            self.contrast = 0.2
            self.saturation = 0.2
            self.hue = 0.1

        self.resize = {}
        for i in range(self.num_scales):
            s = 2 ** i
            self.resize[i] = transforms.Resize((self.height // s, self.width // s),
                                               interpolation=self.interp)

        self.load_depth = self.check_depth()

    def preprocess(self, inputs, color_aug):
        """Resize colour images to the required scales and augment if required

        We create the color_aug object in advance and apply the same augmentation to all
        images in this item. This ensures that all images input to the pose network receive the
        same augmentation.
        """
        for k in list(inputs):
            frame = inputs[k]
            if "color" in k:
                n, im, i = k
                for i in range(self.num_scales):
                    inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])

        for k in list(inputs):
            f = inputs[k]
            if "color" in k:
                n, im, i = k
                inputs[(n, im, i)] = self.to_tensor(f)
                inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Returns a single training item from the dataset as a dictionary.

        Values correspond to torch tensors.
        Keys in the dictionary are either strings or tuples:

            ("color", <frame_id>, <scale>)       for raw colour images,
            ("color_aug", <frame_id>, <scale>)   for augmented colour images,
            ("K", scale) or ("inv_K", scale)     for camera intrinsics,
            "stereo_T"                           for camera extrinsics, and
            "depth_gt"                           for ground truth depth maps.

        <frame_id> is either:
            an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
        or
            "s" for the opposite image in the stereo pair.

        <scale> is an integer representing the scale of the image relative to the full-size image:
            -1      images at native resolution as loaded from disk
             0      images resized to (self.width,      self.height     )
             1      images resized to (self.width // 2, self.height // 2)
             2      images resized to (self.width // 4, self.height // 4)
             3      images resized to (self.width // 8, self.height // 8)
        """
        inputs = {}

        do_color_aug = self.is_train and random.random() > 0.5
        do_flip = self.is_train and random.random() > 0.5

        line = self.filenames[index].split()
        folder = line[0]

        if len(line) == 3:
            frame_index = int(line[1])
        else:
            frame_index = 0

        if len(line) == 3:
            side = line[2]
        else:
            side = None

        for i in self.frame_idxs:
            if i == "s":
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        if do_color_aug:
            color_aug = transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            color_aug = (lambda x: x)

        self.preprocess(inputs, color_aug)

        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        if "s" in self.frame_idxs:
            stereo_T = np.eye(4, dtype=np.float32)
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = torch.from_numpy(stereo_T)

        return inputs

    def get_color(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

    def check_depth(self):
        raise NotImplementedError

    def get_depth(self, folder, frame_index, side, do_flip):
        raise NotImplementedError


class KITTIDataset(MonoDataset):
    """Superclass for different types of KITTI dataset loaders
    """

    def __init__(self, *args, **kwargs):
        super(KITTIDataset, self).__init__(*args, **kwargs)

        # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
        # To normalize you need to scale the first row by 1 / image_width and the second row
        # by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
        # If your principal point is far from the center you might need to disable the horizontal
        # flip augmentation.
        self.K = np.array([[0.58, 0, 0.5, 0],
                           [0, 1.92, 0.5, 0],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=np.float32)

        self.full_res_shape = (1242, 375)
        self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}

    def check_depth(self):
        line = self.filenames[0].split()
        scene_name = line[0]
        frame_index = int(line[1])

        velo_filename = os.path.join(
            self.data_path,
            scene_name,
            "velodyne_points/data/{:010d}.bin".format(int(frame_index)))

        return os.path.isfile(velo_filename)

    def get_color(self, folder, frame_index, side, do_flip):
        color = self.loader(self.get_image_path(folder, frame_index, side))

        if do_flip:
            color = color.transpose(Image.FLIP_LEFT_RIGHT)

        return color


class KITTIDepthDataset(KITTIDataset):
    """KITTI dataset which uses the updated ground truth depth maps
    """

    def __init__(self, *args, **kwargs):
        super(KITTIDepthDataset, self).__init__(*args, **kwargs)

    def get_image_path(self, folder, frame_index, side):
        f_str = "{:010d}{}".format(frame_index, self.img_ext)
        image_path = os.path.join(
            self.data_path,
            folder,
            "image_0{}/data".format(self.side_map[side]),
            f_str)
        return image_path

    def get_depth(self, folder, frame_index, side, do_flip):
        f_str = "{:010d}.png".format(frame_index)
        depth_path = os.path.join(
            self.data_path,
            folder,
            "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
            f_str)

        depth_gt = Image.open(depth_path)
        depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
        depth_gt = np.array(depth_gt).astype(np.float32) / 256

        if do_flip:
            depth_gt = np.fliplr(depth_gt)

        return depth_gt
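

# Usage sketch (illustrative; data_path, the split line, and the sizes are assumptions about
# the on-disk layout, following the monodepth2 conventions used above):
#   filenames = ["2011_09_26/2011_09_26_drive_0001_sync 1 l"]   # "folder frame_index side"
#   dset = KITTIDepthDataset(data_path="kitti_data", filenames=filenames,
#                            height=192, width=640, frame_idxs=[0, -1, 1],
#                            num_scales=4, is_train=True, img_ext='.png')
#   sample = dset[0]
#   sample[("color", 0, 0)]   # colour tensor at full training resolution
#   sample[("K", 0)]          # intrinsics at scale 0
#   sample["depth_gt"]        # present only when check_depth() finds the raw velodyne file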