"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import importlib
from copy import deepcopy
from math import ceil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin

from unidepth.models.unidepthv1.decoder import Decoder
from unidepth.utils.constants import (IMAGENET_DATASET_MEAN,
                                      IMAGENET_DATASET_STD)
from unidepth.utils.distributed import is_main_process
from unidepth.utils.geometric import (generate_rays,
                                      spherical_zbuffer_to_euclidean)
from unidepth.utils.misc import get_params

MAP_BACKBONES = {"ViTL14": "vitl14", "ConvNextL": "cnvnxtl"}


# inference helpers
def _paddings(image_shape, network_shape):
    cur_h, cur_w = image_shape
    h, w = network_shape
    pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
    pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
    return pad_left, pad_right, pad_top, pad_bottom


def _shapes(image_shape, network_shape):
    h, w = image_shape
    input_ratio = w / h
    output_ratio = network_shape[1] / network_shape[0]
    if output_ratio > input_ratio:
        ratio = network_shape[0] / h
    else:  # output_ratio <= input_ratio
        ratio = network_shape[1] / w
    return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio
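
# Worked example (the network shape here is illustrative, e.g. (462, 616)):
#   _shapes((1080, 1920), (462, 616)) -> ((346, 616), 616 / 1920)
#     the input has the wider aspect, so it is fit to the network width;
#   _paddings((346, 616), (462, 616)) -> (0, 0, 58, 58)
#     i.e. 58 px of constant padding above and below fill the height.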


def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes):
    (pad_left, pad_right, pad_top, pad_bottom) = pads
    rgbs = F.interpolate(
        rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True
    )
    rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant")
    if intrinsics is not None:
        intrinsics = intrinsics.clone()
        intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
        intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
        intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left
        intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top
        return rgbs, intrinsics
    return rgbs, None
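
# The intrinsics follow the same resize-then-pad mapping as the pixels:
# focal lengths (fx, fy) scale by `ratio`, while the principal point
# (cx, cy) scales by `ratio` and shifts by the left/top padding.
# Continuing the example above: fx' = 0.3208 * fx, cy' = 0.3208 * cy + 58.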


def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes):
    (pad_left, pad_right, pad_top, pad_bottom) = pads
    # pred mean, trim paddings, and upsample to input dim
    predictions = sum(
        [
            F.interpolate(
                x.clone(),
                size=shapes,
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            for x in predictions
        ]
    ) / len(predictions)
    predictions = predictions[
        ..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right
    ]
    predictions = F.interpolate(
        predictions,
        size=original_shapes,
        mode="bilinear",
        align_corners=False,
        antialias=True,
    )
    intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio
    intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio
    intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio
    intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio
    return predictions, intrinsics
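
# _postprocess inverts _preprocess: the list of depth predictions is
# averaged at network resolution, the padding is cropped, the result is
# resized back to the original (H, W), and the intrinsics mapping above
# is undone (subtract the padding offset, then divide by `ratio`).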


class UniDepthV1(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="UniDepth",
    repo_url="https://github.com/lpiccinelli-eth/UniDepth",
    tags=["monocular-metric-depth-estimation"],
):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__()
        self.build(config)
        self.eps = eps

    def forward(self, inputs, image_metas=None):
        rgbs = inputs["image"]
        gt_intrinsics = inputs.get("K")
        H, W = rgbs.shape[-2:]

        # Encode
        encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
        if "dino" in self.pixel_encoder.__class__.__name__.lower():
            encoder_outputs = [
                (x + y.unsqueeze(1)).contiguous()
                for x, y in zip(encoder_outputs, cls_tokens)
            ]
        inputs["encoder_outputs"] = encoder_outputs
        inputs["cls_tokens"] = cls_tokens

        # Get camera infos, if any
        if gt_intrinsics is not None:
            rays, angles = generate_rays(
                gt_intrinsics, self.image_shape, noisy=self.training
            )
            inputs["rays"] = rays
            inputs["angles"] = angles
            inputs["K"] = gt_intrinsics
            self.pixel_decoder.test_fixed_camera = True  # use GT camera in fwd

        # Decode
        pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {})
        predictions = sum(
            [
                F.interpolate(
                    x.clone(),
                    size=self.image_shape,
                    mode="bilinear",
                    align_corners=False,
                    antialias=True,
                )
                for x in predictions
            ]
        ) / len(predictions)

        # Final 3D points backprojection
        pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False)[-1]
        # Note: inputs["angles"] could be reused here when GT intrinsics are available.
        pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W)
        pred_angles = F.interpolate(
            pred_angles.clone(),
            size=self.image_shape,
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )
        points_3d = torch.cat((pred_angles, predictions), dim=1)
        points_3d = spherical_zbuffer_to_euclidean(
            points_3d.permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)

        # Output data, use for loss computation
        outputs = {
            "angles": pred_angles,
            "intrinsics": pred_intrinsics,
            "points": points_3d,
            "depth": predictions.squeeze(1),
        }
        self.pixel_decoder.test_fixed_camera = False
        return outputs
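
    # Shape sketch for `forward`, with (h, w) = self.image_shape:
    #   inputs["image"]: (B, 3, H, W) normalized RGB
    #   outputs: "depth" (B, h, w), "angles" (B, 2, h, w),
    #            "points" (B, 3, h, w), "intrinsics" (B, 3, 3)
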
    @torch.no_grad()
    def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False):
        if rgbs.ndim == 3:
            rgbs = rgbs.unsqueeze(0)
        if intrinsics is not None and intrinsics.ndim == 2:
            intrinsics = intrinsics.unsqueeze(0)
        B, _, H, W = rgbs.shape
        rgbs = rgbs.to(self.device)
        if intrinsics is not None:
            intrinsics = intrinsics.to(self.device)

        # process image and intrinsics (if any) to match network input (slow?)
        if rgbs.max() > 5 or rgbs.dtype == torch.uint8:
            rgbs = rgbs.to(torch.float32).div(255)
        if rgbs.min() >= 0.0 and rgbs.max() <= 1.0:
            rgbs = TF.normalize(
                rgbs,
                mean=IMAGENET_DATASET_MEAN,
                std=IMAGENET_DATASET_STD,
            )
        (h, w), ratio = _shapes((H, W), self.image_shape)
        pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), self.image_shape)
        rgbs, gt_intrinsics = _preprocess(
            rgbs,
            intrinsics,
            (h, w),
            (pad_left, pad_right, pad_top, pad_bottom),
            ratio,
            self.image_shape,
        )

        # run encoder
        encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
        if "dino" in self.pixel_encoder.__class__.__name__.lower():
            encoder_outputs = [
                (x + y.unsqueeze(1)).contiguous()
                for x, y in zip(encoder_outputs, cls_tokens)
            ]

        # get data for decoder and adapt to given camera
        inputs = {}
        inputs["encoder_outputs"] = encoder_outputs
        inputs["cls_tokens"] = cls_tokens
        inputs["image"] = rgbs
        if gt_intrinsics is not None:
            rays, angles = generate_rays(
                gt_intrinsics, self.image_shape, noisy=self.training
            )
            inputs["rays"] = rays
            inputs["angles"] = angles
            inputs["K"] = gt_intrinsics
            self.pixel_decoder.test_fixed_camera = True
            self.pixel_decoder.skip_camera = skip_camera

        # decode all
        pred_intrinsics, predictions, _ = self.pixel_decoder(inputs, {})

        # undo the reshaping and get original image size (slow)
        predictions, pred_intrinsics = _postprocess(
            predictions,
            pred_intrinsics,
            self.image_shape,
            (pad_left, pad_right, pad_top, pad_bottom),
            ratio,
            (H, W),
        )

        # final 3D points backprojection
        intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics
        angles = generate_rays(intrinsics, (H, W), noisy=False)[-1]
        angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
        points_3d = torch.cat((angles, predictions), dim=1)
        points_3d = spherical_zbuffer_to_euclidean(
            points_3d.permute(0, 2, 3, 1)
        ).permute(0, 3, 1, 2)

        # output data
        outputs = {
            "intrinsics": pred_intrinsics,
            "points": points_3d,
            "depth": predictions[:, -1:],
        }
        self.pixel_decoder.test_fixed_camera = False
        self.pixel_decoder.skip_camera = False
        return outputs

    def load_pretrained(self, model_file):
        device = (
            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        )
        dict_model = torch.load(model_file, map_location=device)
        if "model" in dict_model:
            dict_model = dict_model["model"]
        new_state_dict = deepcopy(
            {k.replace("module.", ""): v for k, v in dict_model.items()}
        )
        info = self.load_state_dict(new_state_dict, strict=False)
        if is_main_process():
            print(
                f"Loaded from {model_file} for {self.__class__.__name__} results in:",
                info,
            )

    def get_params(self, config):
        if hasattr(self.pixel_encoder, "get_params"):
            encoder_p, encoder_lr = self.pixel_encoder.get_params(
                config["model"]["pixel_encoder"]["lr"],
                config["training"]["wd"],
                config["training"]["ld"],
            )
        else:
            encoder_p, encoder_lr = get_params(
                self.pixel_encoder,
                config["model"]["pixel_encoder"]["lr"],
                config["training"]["wd"],
            )
        decoder_p, decoder_lr = get_params(
            self.pixel_decoder, config["training"]["lr"], config["training"]["wd"]
        )
        return [*encoder_p, *decoder_p], [*encoder_lr, *decoder_lr]

    @property
    def device(self):
        return next(self.parameters()).device

    def build(self, config):
        mod = importlib.import_module("unidepth.models.encoder")
        pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
        pixel_encoder_config = {
            **config["training"],
            **config["data"],
            **config["model"]["pixel_encoder"],
            "interpolate_offset": 0.1,
        }
        pixel_encoder = pixel_encoder_factory(pixel_encoder_config)

        config["model"]["pixel_encoder"]["patch_size"] = (
            14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
        )
        pixel_encoder_embed_dims = (
            pixel_encoder.embed_dims
            if hasattr(pixel_encoder, "embed_dims")
            else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
        )
        config["model"]["pixel_encoder"]["embed_dim"] = getattr(
            pixel_encoder, "embed_dim"
        )
        config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
        config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths

        self.pixel_encoder = pixel_encoder
        self.pixel_decoder = Decoder(config)
        self.image_shape = config["data"]["image_shape"]
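

# Minimal usage sketch. The checkpoint id and image path below are
# illustrative assumptions, not guaranteed by this file; `from_pretrained`
# is provided by PyTorchModelHubMixin.
if __name__ == "__main__":
    import numpy as np
    from PIL import Image

    model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")  # assumed repo id
    model = model.eval().to("cuda" if torch.cuda.is_available() else "cpu")

    # HWC uint8 image -> CHW tensor; `infer` rescales and normalizes internally
    rgb = torch.from_numpy(np.array(Image.open("image.png"))).permute(2, 0, 1)
    outputs = model.infer(rgb)  # no intrinsics given: the camera is predicted
    depth = outputs["depth"]  # (1, 1, H, W) metric depth
    K = outputs["intrinsics"]  # (1, 3, 3) pinhole intrinsics
    points = outputs["points"]  # (1, 3, H, W) 3D points in camera frame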