""" Author: Luigi Piccinelli Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) """ import importlib from copy import deepcopy from math import ceil import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms.functional as TF from einops import rearrange from huggingface_hub import PyTorchModelHubMixin from unidepth.models.unidepthv1.decoder import Decoder from unidepth.utils.constants import (IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD) from unidepth.utils.distributed import is_main_process from unidepth.utils.geometric import (generate_rays, spherical_zbuffer_to_euclidean) from unidepth.utils.misc import get_params MAP_BACKBONES = {"ViTL14": "vitl14", "ConvNextL": "cnvnxtl"} # inference helpers def _paddings(image_shape, network_shape): cur_h, cur_w = image_shape h, w = network_shape pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2 pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 return pad_left, pad_right, pad_top, pad_bottom def _shapes(image_shape, network_shape): h, w = image_shape input_ratio = w / h output_ratio = network_shape[1] / network_shape[0] if output_ratio > input_ratio: ratio = network_shape[0] / h elif output_ratio <= input_ratio: ratio = network_shape[1] / w return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes): (pad_left, pad_right, pad_top, pad_bottom) = pads rgbs = F.interpolate( rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True ) rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant") if intrinsics is not None: intrinsics = intrinsics.clone() intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top return rgbs, intrinsics return rgbs, None def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes): (pad_left, pad_right, pad_top, pad_bottom) = pads # pred mean, trim paddings, and upsample to input dim predictions = sum( [ F.interpolate( x.clone(), size=shapes, mode="bilinear", align_corners=False, antialias=True, ) for x in predictions ] ) / len(predictions) predictions = predictions[ ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right ] predictions = F.interpolate( predictions, size=original_shapes, mode="bilinear", align_corners=False, antialias=True, ) intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio return predictions, intrinsics class UniDepthV1( nn.Module, PyTorchModelHubMixin, library_name="UniDepth", repo_url="https://github.com/lpiccinelli-eth/UniDepth", tags=["monocular-metric-depth-estimation"], ): def __init__( self, config, eps: float = 1e-6, **kwargs, ): super().__init__() self.build(config) self.eps = eps def forward(self, inputs, image_metas=None): rgbs = inputs['image'] gt_intrinsics = inputs.get('K') H, W = rgbs.shape[-2:] # Encode encoder_outputs, cls_tokens = self.pixel_encoder(rgbs) if "dino" in self.pixel_encoder.__class__.__name__.lower(): encoder_outputs = [ (x + y.unsqueeze(1)).contiguous() for x, y in zip(encoder_outputs, cls_tokens) ] inputs["encoder_outputs"] = encoder_outputs inputs["cls_tokens"] = cls_tokens # Get camera infos, if 
def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes):
    (pad_left, pad_right, pad_top, pad_bottom) = pads
    rgbs = F.interpolate(
        rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True
    )
    rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant")
    if intrinsics is not None:
        intrinsics = intrinsics.clone()
        intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
        intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
        intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left
        intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top
        return rgbs, intrinsics
    return rgbs, None


def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes):
    (pad_left, pad_right, pad_top, pad_bottom) = pads
    # average the predictions, trim the paddings, and upsample to the input size
    predictions = sum(
        [
            F.interpolate(
                x.clone(),
                size=shapes,
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            for x in predictions
        ]
    ) / len(predictions)
    predictions = predictions[
        ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right
    ]
    predictions = F.interpolate(
        predictions,
        size=original_shapes,
        mode="bilinear",
        align_corners=False,
        antialias=True,
    )
    intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio
    intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio
    intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio
    intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio
    return predictions, intrinsics


class UniDepthV1(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="UniDepth",
    repo_url="https://github.com/lpiccinelli-eth/UniDepth",
    tags=["monocular-metric-depth-estimation"],
):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__()
        self.build(config)
        self.eps = eps

    def forward(self, inputs, image_metas=None):
        rgbs = inputs["image"]
        gt_intrinsics = inputs.get("K")
        H, W = rgbs.shape[-2:]

        # Encode
        encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
        if "dino" in self.pixel_encoder.__class__.__name__.lower():
            encoder_outputs = [
                (x + y.unsqueeze(1)).contiguous()
                for x, y in zip(encoder_outputs, cls_tokens)
            ]
        inputs["encoder_outputs"] = encoder_outputs
        inputs["cls_tokens"] = cls_tokens

        # Get camera infos, if any
        if gt_intrinsics is not None:
            rays, angles = generate_rays(
                gt_intrinsics, self.image_shape, noisy=self.training
            )
            inputs["rays"] = rays
            inputs["angles"] = angles
            inputs["K"] = gt_intrinsics
            self.pixel_decoder.test_fixed_camera = True  # use the GT camera in the forward pass

        # Decode
        pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {})
        predictions = sum(
            [
                F.interpolate(
                    x.clone(),
                    size=self.image_shape,
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
                for x in predictions
            ]
        ) / len(predictions)

        # Final 3D points backprojection
        # You may want to use inputs["angles"] here, if available
        pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False)[-1]
        pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W)
        pred_angles = F.interpolate(
            pred_angles.clone(),
            size=self.image_shape,
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        points_3d = torch.cat((pred_angles, predictions), dim=1)
        points_3d = spherical_zbuffer_to_euclidean(
            points_3d.permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)

        # Output data, used for loss computation
        outputs = {
            "angles": pred_angles,
            "intrinsics": pred_intrinsics,
            "points": points_3d,
            "depth": predictions.squeeze(1),
        }
        self.pixel_decoder.test_fixed_camera = False
        return outputs

    @torch.no_grad()
    def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False):
        if rgbs.ndim == 3:
            rgbs = rgbs.unsqueeze(0)
        if intrinsics is not None and intrinsics.ndim == 2:
            intrinsics = intrinsics.unsqueeze(0)
        B, _, H, W = rgbs.shape

        rgbs = rgbs.to(self.device)
        if intrinsics is not None:
            intrinsics = intrinsics.to(self.device)

        # process image and intrinsics (if any) to match the network input (slow?)
        if rgbs.max() > 5 or rgbs.dtype == torch.uint8:
            rgbs = rgbs.to(torch.float32).div(255)
        if rgbs.min() >= 0.0 and rgbs.max() <= 1.0:
            rgbs = TF.normalize(
                rgbs,
                mean=IMAGENET_DATASET_MEAN,
                std=IMAGENET_DATASET_STD,
            )

        (h, w), ratio = _shapes((H, W), self.image_shape)
        pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), self.image_shape)
        rgbs, gt_intrinsics = _preprocess(
            rgbs,
            intrinsics,
            (h, w),
            (pad_left, pad_right, pad_top, pad_bottom),
            ratio,
            self.image_shape,
        )

        # run encoder
        encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
        if "dino" in self.pixel_encoder.__class__.__name__.lower():
            encoder_outputs = [
                (x + y.unsqueeze(1)).contiguous()
                for x, y in zip(encoder_outputs, cls_tokens)
            ]

        # get data for the decoder and adapt it to the given camera
        inputs = {}
        inputs["encoder_outputs"] = encoder_outputs
        inputs["cls_tokens"] = cls_tokens
        inputs["image"] = rgbs
        if gt_intrinsics is not None:
            rays, angles = generate_rays(
                gt_intrinsics, self.image_shape, noisy=self.training
            )
            inputs["rays"] = rays
            inputs["angles"] = angles
            inputs["K"] = gt_intrinsics
            self.pixel_decoder.test_fixed_camera = True
            self.pixel_decoder.skip_camera = skip_camera

        # decode all
        pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {})

        # undo the reshaping and get the original image size (slow)
        predictions, pred_intrinsics = _postprocess(
            predictions,
            pred_intrinsics,
            self.image_shape,
            (pad_left, pad_right, pad_top, pad_bottom),
            ratio,
            (H, W),
        )

        # final 3D points backprojection
        intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics
        angles = generate_rays(intrinsics, (H, W), noisy=False)[-1]
        angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
        points_3d = torch.cat((angles, predictions), dim=1)
        points_3d = spherical_zbuffer_to_euclidean(
            points_3d.permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)

        # output data
        outputs = {
            "intrinsics": pred_intrinsics,
            "points": points_3d,
            "depth": predictions[:, -1:],
        }
        self.pixel_decoder.test_fixed_camera = False
        self.pixel_decoder.skip_camera = False
        return outputs
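
    # `load_pretrained` (below) accepts either a raw state_dict or a training
    # checkpoint wrapped as {"model": state_dict}; it strips the "module."
    # prefix that DistributedDataParallel adds to parameter names and loads
    # with strict=False, so missing/unexpected keys are reported, not fatal.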
    def load_pretrained(self, model_file):
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        dict_model = torch.load(model_file, map_location=device)
        if "model" in dict_model:
            dict_model = dict_model["model"]
        new_state_dict = deepcopy(
            {k.replace("module.", ""): v for k, v in dict_model.items()}
        )
        info = self.load_state_dict(new_state_dict, strict=False)
        if is_main_process():
            print(
                f"Loaded from {model_file} for {self.__class__.__name__} results in:",
                info,
            )

    def get_params(self, config):
        if hasattr(self.pixel_encoder, "get_params"):
            encoder_p, encoder_lr = self.pixel_encoder.get_params(
                config["model"]["pixel_encoder"]["lr"],
                config["training"]["wd"],
                config["training"]["ld"],
            )
        else:
            encoder_p, encoder_lr = get_params(
                self.pixel_encoder,
                config["model"]["pixel_encoder"]["lr"],
                config["training"]["wd"],
            )
        decoder_p, decoder_lr = get_params(
            self.pixel_decoder, config["training"]["lr"], config["training"]["wd"]
        )
        return [*encoder_p, *decoder_p], [*encoder_lr, *decoder_lr]

    @property
    def device(self):
        return next(self.parameters()).device

    def build(self, config):
        mod = importlib.import_module("unidepth.models.encoder")
        pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
        pixel_encoder_config = {
            **config["training"],
            **config["data"],
            **config["model"]["pixel_encoder"],
            "interpolate_offset": 0.1,
        }
        pixel_encoder = pixel_encoder_factory(pixel_encoder_config)

        config["model"]["pixel_encoder"]["patch_size"] = (
            14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
        )
        pixel_encoder_embed_dims = (
            pixel_encoder.embed_dims
            if hasattr(pixel_encoder, "embed_dims")
            else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
        )
        config["model"]["pixel_encoder"]["embed_dim"] = getattr(
            pixel_encoder, "embed_dim"
        )
        config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
        config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths

        self.pixel_encoder = pixel_encoder
        self.pixel_decoder = Decoder(config)
        self.image_shape = config["data"]["image_shape"]
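

# A minimal usage sketch for `infer`, assuming the HuggingFace Hub checkpoint
# id "lpiccinelli/unidepth-v1-vitl14" (following the project README naming);
# swap in the checkpoint you actually use. The random tensor below only
# stands in for a real (C, H, W) RGB image.
if __name__ == "__main__":
    model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
    model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

    rgb = torch.randint(0, 255, (3, 480, 640), dtype=torch.uint8)
    outputs = model.infer(rgb)  # intrinsics are optional and omitted here
    print(outputs["depth"].shape)       # expected: (1, 1, 480, 640)
    print(outputs["points"].shape)      # expected: (1, 3, 480, 640)
    print(outputs["intrinsics"].shape)  # expected: (1, 3, 3)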