"""
Author: Luigi Piccinelli
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
"""
import argparse
import json
import os
import tempfile
from math import ceil

import huggingface_hub
import torch.nn.functional as F
import torch.onnx

from unidepth.models.unidepthv2 import UniDepthV2
from unidepth.utils.geometric import generate_rays
class UniDepthV2ONNX(UniDepthV2):
    """ONNX-exportable UniDepthV2 head that infers camera intrinsics itself.

    Takes only an RGB batch and returns ``(K, depth, confidence)``, with the
    depth and confidence maps upsampled back to the input resolution.
    """

    def __init__(self, config, eps: float = 1e-6, **kwargs):
        super().__init__(config, eps)

    def forward(self, rgbs):
        height, width = rgbs.shape[-2], rgbs.shape[-1]

        # Backbone pass: per-layer feature maps plus per-layer class tokens.
        enc_features, enc_tokens = self.pixel_encoder(rgbs)
        cls_tokens = [tok.contiguous() for tok in enc_tokens]

        # Stack encoder outputs over the configured layer ranges.
        stacked_features = []
        stacked_tokens = []
        for lo, hi in self.slices_encoder_range:
            stacked_features.append(self.stacking_fn(enc_features[lo:hi]).contiguous())
            stacked_tokens.append(self.stacking_fn(enc_tokens[lo:hi]).contiguous())

        decoder_inputs = {
            "image": rgbs,
            "features": stacked_features,
            "tokens": stacked_tokens,
            # Last two class tokens feed the global head; last three (plus the
            # penultimate stacked tokens) feed the camera head.
            "global_tokens": [cls_tokens[-2], cls_tokens[-1]],
            "camera_tokens": [cls_tokens[-3], cls_tokens[-2], cls_tokens[-1], stacked_tokens[-2]],
        }

        outs = self.pixel_decoder(decoder_inputs, {})

        # Bring the decoder outputs back to the input spatial size.
        depth = F.interpolate(outs["depth"], size=(height, width), mode="bilinear")
        confidence = F.interpolate(outs["confidence"], size=(height, width), mode="bilinear")
        return outs["K"], depth, confidence
class UniDepthV2wCamONNX(UniDepthV2):
    """ONNX-exportable UniDepthV2 head conditioned on a known camera.

    Takes an RGB batch and a ground-truth intrinsics matrix ``K`` and returns
    ``(K, depth, depth_ssi, confidence)``; all dense outputs are upsampled
    back to the input resolution.
    """

    def __init__(self, config, eps: float = 1e-6, **kwargs):
        super().__init__(config, eps)

    def forward(self, rgbs, K):
        height, width = rgbs.shape[-2], rgbs.shape[-1]

        # Backbone pass: per-layer feature maps plus per-layer class tokens.
        enc_features, enc_tokens = self.pixel_encoder(rgbs)
        cls_tokens = [tok.contiguous() for tok in enc_tokens]

        # Stack encoder outputs over the configured layer ranges.
        stacked_features = []
        stacked_tokens = []
        for lo, hi in self.slices_encoder_range:
            stacked_features.append(self.stacking_fn(enc_features[lo:hi]).contiguous())
            stacked_tokens.append(self.stacking_fn(enc_tokens[lo:hi]).contiguous())

        # Condition the decoder on the provided intrinsics via ray embeddings.
        rays, angles = generate_rays(K, (height, width))

        decoder_inputs = {
            "image": rgbs,
            "features": stacked_features,
            "tokens": stacked_tokens,
            # Last two class tokens feed the global head; last three (plus the
            # penultimate stacked tokens) feed the camera head.
            "global_tokens": [cls_tokens[-2], cls_tokens[-1]],
            "camera_tokens": [cls_tokens[-3], cls_tokens[-2], cls_tokens[-1], stacked_tokens[-2]],
            "rays": rays,
            "angles": angles,
            "K": K,
        }

        outs = self.pixel_decoder(decoder_inputs, {})

        # Bring the decoder outputs back to the input spatial size.
        depth = F.interpolate(outs["depth"], size=(height, width), mode="bilinear")
        depth_ssi = F.interpolate(outs["depth_ssi"], size=(height, width), mode="bilinear")
        confidence = F.interpolate(outs["confidence"], size=(height, width), mode="bilinear")
        return outs["K"], depth, depth_ssi, confidence
def export(model, path, shape=(462, 616), with_camera=False):
model.eval()
image = torch.rand(1, 3, *shape)
dynamic_axes_in = {"image": {0: "batch"}}
inputs = [image]
if with_camera:
K = torch.rand(1, 3, 3)
inputs.append(K)
dynamic_axes_in["K"] = {0: "batch"}
dynamic_axes_out = {
"out_K": {0: "batch"},
"depth": {0: "batch"},
"confidence": {0: "batch"},
}
torch.onnx.export(
model,
tuple(inputs),
path,
input_names=list(dynamic_axes_in.keys()),
output_names=list(dynamic_axes_out.keys()),
opset_version=14,
dynamic_axes={**dynamic_axes_in, **dynamic_axes_out},
)
print(f"Model exported to {path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Export UniDepthV2 model to ONNX")
parser.add_argument(
"--version", type=str, default="v2", choices=["v2"], help="UniDepth version"
)
parser.add_argument(
"--backbone",
type=str,
default="vitl14",
choices=["vits14", "vitl14"],
help="Backbone model",
)
parser.add_argument(
"--shape",
type=int,
nargs=2,
default=(462, 616),
help="Input shape. No dyamic shape supported!",
)
parser.add_argument(
"--output-path", type=str, default="unidepthv2.onnx", help="Output ONNX file"
)
parser.add_argument(
"--with-camera",
action="store_true",
help="Export model that expects GT camera matrix at inference",
)
args = parser.parse_args()
version = args.version
backbone = args.backbone
shape = args.shape
output_path = args.output_path
with_camera = args.with_camera
# force shape to be multiple of 14
shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape]
if list(shape) != list(shape_rounded):
print(f"Shape {shape} is not multiple of 14. Rounding to {shape_rounded}")
shape = shape_rounded
# assumes command is from root of repo
with open(os.path.join("configs", f"config_{version}_{backbone}.json")) as f:
config = json.load(f)
# tell DINO not to use efficient attention: not exportable
config["training"]["export"] = True
model_factory = UniDepthV2ONNX if not with_camera else UniDepthV2wCamONNX
model = model_factory(config)
path = huggingface_hub.hf_hub_download(
repo_id=f"lpiccinelli/unidepth-{version}-{backbone}",
filename=f"pytorch_model.bin",
repo_type="model",
)
info = model.load_state_dict(torch.load(path), strict=False)
print(f"UniDepth_{version}_{backbone} is loaded with:")
print(f"\t missing keys: {info.missing_keys}")
print(f"\t additional keys: {info.unexpected_keys}")
export(
model=model,
path=os.path.join(os.environ["TMPDIR"], output_path),
shape=shape,
with_camera=with_camera,
)