"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)

Export a pretrained UniDepthV2 model to ONNX, either predicting the camera
itself or conditioned on a user-supplied camera intrinsics matrix K.
"""

import argparse
import json
import os
import tempfile
from math import ceil

import huggingface_hub
import torch.nn.functional as F
import torch.onnx

from unidepth.models.unidepthv2 import UniDepthV2
from unidepth.utils.geometric import generate_rays

# Input height and width are forced to a multiple of this value before export.
PATCH_SIZE = 14


def _prepare_decoder_inputs(model, rgbs):
    """Run ``model.pixel_encoder`` on ``rgbs`` and build the decoder input dict.

    Shared by both ONNX wrappers so their encoder path cannot drift apart.

    Returns:
        dict with keys "image", "features", "tokens", "global_tokens" and
        "camera_tokens", in that insertion order.
    """
    features, tokens = model.pixel_encoder(rgbs)

    cls_tokens = [x.contiguous() for x in tokens]
    # Group the per-layer encoder outputs according to the model's slice ranges.
    features = [
        model.stacking_fn(features[i:j]).contiguous()
        for i, j in model.slices_encoder_range
    ]
    tokens = [
        model.stacking_fn(tokens[i:j]).contiguous()
        for i, j in model.slices_encoder_range
    ]
    global_tokens = [cls_tokens[i] for i in [-2, -1]]
    camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]]

    return {
        "image": rgbs,
        "features": features,
        "tokens": tokens,
        "global_tokens": global_tokens,
        "camera_tokens": camera_tokens,
    }


def _resize(x, size):
    """Bilinearly resize ``x`` to spatial ``size`` (align_corners=False, the default)."""
    return F.interpolate(x, size=size, mode="bilinear", align_corners=False)


class UniDepthV2ONNX(UniDepthV2):
    """UniDepthV2 with a trace-friendly forward; the camera is predicted by the model."""

    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        # Extra keyword arguments are accepted for compatibility but not forwarded.
        super().__init__(config, eps)

    def forward(self, rgbs):
        """Predict camera, depth and confidence from an image batch.

        Returns:
            (K, depth, confidence). ``K`` is the decoder's "K" output; depth and
            confidence are resized to the spatial size of ``rgbs``.
        """
        H, W = rgbs.shape[-2:]
        inputs = _prepare_decoder_inputs(self, rgbs)
        outs = self.pixel_decoder(inputs, {})

        depth = _resize(outs["depth"], (H, W))
        confidence = _resize(outs["confidence"], (H, W))
        return outs["K"], depth, confidence


class UniDepthV2wCamONNX(UniDepthV2):
    """UniDepthV2 with a trace-friendly forward that is conditioned on a given camera K."""

    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        # Extra keyword arguments are accepted for compatibility but not forwarded.
        super().__init__(config, eps)

    def forward(self, rgbs, K):
        """Predict depth and confidence given an image batch and intrinsics ``K``.

        Returns:
            (K, depth, depth_normalized, confidence). The three maps come from the
            decoder's "depth", "depth_ssi" and "confidence" outputs, resized to
            the spatial size of ``rgbs``.
        """
        H, W = rgbs.shape[-2:]
        inputs = _prepare_decoder_inputs(self, rgbs)

        # Camera conditioning: rays/angles derived from K at the input resolution.
        rays, angles = generate_rays(K, (H, W))
        inputs["rays"] = rays
        inputs["angles"] = angles
        inputs["K"] = K

        outs = self.pixel_decoder(inputs, {})
        depth = _resize(outs["depth"], (H, W))
        depth_normalized = _resize(outs["depth_ssi"], (H, W))
        confidence = _resize(outs["confidence"], (H, W))
        return outs["K"], depth, depth_normalized, confidence


def export(model, path, shape=(462, 616), with_camera=False):
    """Trace ``model`` with dummy inputs and save it to ``path`` as ONNX (opset 14).

    Args:
        model: a ``UniDepthV2ONNX`` (``with_camera=False``) or
            ``UniDepthV2wCamONNX`` (``with_camera=True``).
        path: destination ONNX file.
        shape: (H, W) of the dummy image used for tracing. Only the batch axis
            is exported as dynamic.
        with_camera: also add a (1, 3, 3) ``K`` input and name the extra
            normalized-depth output that the camera-conditioned model returns.
    """
    model.eval()
    image = torch.rand(1, 3, *shape)
    dynamic_axes_in = {"image": {0: "batch"}}
    inputs = [image]
    if with_camera:
        K = torch.rand(1, 3, 3)
        inputs.append(K)
        dynamic_axes_in["K"] = {0: "batch"}

    # ONNX output names bind positionally to forward()'s return tuple, so the
    # list must mirror it exactly; the camera model returns an extra map third.
    output_names = ["out_K", "depth"]
    if with_camera:
        output_names.append("depth_normalized")
    output_names.append("confidence")
    dynamic_axes_out = {name: {0: "batch"} for name in output_names}

    torch.onnx.export(
        model,
        tuple(inputs),
        path,
        input_names=list(dynamic_axes_in.keys()),
        output_names=output_names,
        opset_version=14,
        dynamic_axes={**dynamic_axes_in, **dynamic_axes_out},
    )
    print(f"Model exported to {path}")


def _round_to_multiple(x, multiple=PATCH_SIZE):
    """Round ``x`` to the nearest multiple of ``multiple`` (ties round down), at least ``multiple``."""
    return multiple * max(1, ceil(x / multiple - 0.5))


def main():
    """Parse CLI args, load pretrained weights from the Hugging Face Hub and export to ONNX."""
    parser = argparse.ArgumentParser(description="Export UniDepthV2 model to ONNX")
    parser.add_argument(
        "--version", type=str, default="v2", choices=["v2"], help="UniDepth version"
    )
    parser.add_argument(
        "--backbone",
        type=str,
        default="vitl14",
        choices=["vits14", "vitl14"],
        help="Backbone model",
    )
    parser.add_argument(
        "--shape",
        type=int,
        nargs=2,
        default=(462, 616),
        help="Input shape. No dynamic shape supported!",
    )
    parser.add_argument(
        "--output-path", type=str, default="unidepthv2.onnx", help="Output ONNX file"
    )
    parser.add_argument(
        "--with-camera",
        action="store_true",
        help="Export model that expects GT camera matrix at inference",
    )
    args = parser.parse_args()

    version = args.version
    backbone = args.backbone
    shape = args.shape
    output_path = args.output_path
    with_camera = args.with_camera

    # force shape to be multiple of PATCH_SIZE
    shape_rounded = [_round_to_multiple(x) for x in shape]
    if list(shape) != shape_rounded:
        print(
            f"Shape {shape} is not multiple of {PATCH_SIZE}. Rounding to {shape_rounded}"
        )
        shape = shape_rounded

    # assumes command is from root of repo
    with open(os.path.join("configs", f"config_{version}_{backbone}.json")) as f:
        config = json.load(f)
    # tell DINO not to use efficient attention: not exportable
    config["training"]["export"] = True

    model_factory = UniDepthV2wCamONNX if with_camera else UniDepthV2ONNX
    model = model_factory(config)
    path = huggingface_hub.hf_hub_download(
        repo_id=f"lpiccinelli/unidepth-{version}-{backbone}",
        filename="pytorch_model.bin",
        repo_type="model",
    )
    # map_location="cpu": load even on machines without the device the
    # checkpoint tensors were saved from.
    info = model.load_state_dict(torch.load(path, map_location="cpu"), strict=False)
    print(f"UniDepth_{version}_{backbone} is loaded with:")
    print(f"\t missing keys: {info.missing_keys}")
    print(f"\t additional keys: {info.unexpected_keys}")

    # Relative output paths are placed in the temp dir. gettempdir() honours
    # $TMPDIR when set and falls back to a platform default instead of raising
    # KeyError. An absolute --output-path overrides the directory
    # (os.path.join semantics).
    export(
        model=model,
        path=os.path.join(tempfile.gettempdir(), output_path),
        shape=shape,
        with_camera=with_camera,
    )


if __name__ == "__main__":
    main()