""" | |
Author: Luigi Piccinelli | |
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) | |
""" | |
import argparse
import json
import os
from math import ceil

import huggingface_hub
import torch
import torch.nn.functional as F
import torch.onnx

from unidepth.models.unidepthv2 import UniDepthV2
from unidepth.utils.geometric import generate_rays

class UniDepthV2ONNX(UniDepthV2):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(config, eps)

    def forward(self, rgbs):
        H, W = rgbs.shape[-2:]
        features, tokens = self.pixel_encoder(rgbs)
        cls_tokens = [x.contiguous() for x in tokens]
        # Group encoder features and tokens into the slices the decoder expects.
        features = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        tokens = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        global_tokens = [cls_tokens[i] for i in [-2, -1]]
        camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]]
        inputs = {}
        inputs["image"] = rgbs
        inputs["features"] = features
        inputs["tokens"] = tokens
        inputs["global_tokens"] = global_tokens
        inputs["camera_tokens"] = camera_tokens
        outs = self.pixel_decoder(inputs, {})
        # Upsample the decoder outputs back to the input resolution.
        predictions = F.interpolate(
            outs["depth"],
            size=(H, W),
            mode="bilinear",
        )
        confidence = F.interpolate(
            outs["confidence"],
            size=(H, W),
            mode="bilinear",
        )
        return outs["K"], predictions, confidence

class UniDepthV2wCamONNX(UniDepthV2):
    def __init__(
        self,
        config,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(config, eps)

    def forward(self, rgbs, K):
        H, W = rgbs.shape[-2:]
        features, tokens = self.pixel_encoder(rgbs)
        cls_tokens = [x.contiguous() for x in tokens]
        features = [
            self.stacking_fn(features[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        tokens = [
            self.stacking_fn(tokens[i:j]).contiguous()
            for i, j in self.slices_encoder_range
        ]
        global_tokens = [cls_tokens[i] for i in [-2, -1]]
        camera_tokens = [cls_tokens[i] for i in [-3, -2, -1]] + [tokens[-2]]
        inputs = {}
        inputs["image"] = rgbs
        inputs["features"] = features
        inputs["tokens"] = tokens
        inputs["global_tokens"] = global_tokens
        inputs["camera_tokens"] = camera_tokens
        # Condition the decoder on the ground-truth camera: rays and angles
        # are derived from the intrinsics K.
        rays, angles = generate_rays(K, (H, W))
        inputs["rays"] = rays
        inputs["angles"] = angles
        inputs["K"] = K
        outs = self.pixel_decoder(inputs, {})
        predictions = F.interpolate(
            outs["depth"],
            size=(H, W),
            mode="bilinear",
        )
        # The "depth_ssi" (scale/shift-invariant depth) head is only exposed
        # by this camera-conditioned variant.
        predictions_normalized = F.interpolate(
            outs["depth_ssi"],
            size=(H, W),
            mode="bilinear",
        )
        confidence = F.interpolate(
            outs["confidence"],
            size=(H, W),
            mode="bilinear",
        )
        return outs["K"], predictions, predictions_normalized, confidence
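
# For reference, the intrinsics the camera-conditioned model expects are a
# standard (1, 3, 3) pinhole matrix; fx, fy, cx, cy below are placeholders,
# not calibrated values:
#   K = torch.tensor([[[fx, 0.0, cx],
#                      [0.0, fy, cy],
#                      [0.0, 0.0, 1.0]]])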

def export(model, path, shape=(462, 616), with_camera=False):
    model.eval()
    image = torch.rand(1, 3, *shape)
    dynamic_axes_in = {"image": {0: "batch"}}
    inputs = [image]
    if with_camera:
        K = torch.rand(1, 3, 3)
        inputs.append(K)
        dynamic_axes_in["K"] = {0: "batch"}
    dynamic_axes_out = {
        "out_K": {0: "batch"},
        "depth": {0: "batch"},
    }
    if with_camera:
        # The camera-conditioned model returns a fourth tensor, the
        # scale/shift-invariant depth, between depth and confidence.
        dynamic_axes_out["depth_ssi"] = {0: "batch"}
    dynamic_axes_out["confidence"] = {0: "batch"}
    torch.onnx.export(
        model,
        tuple(inputs),
        path,
        input_names=list(dynamic_axes_in.keys()),
        output_names=list(dynamic_axes_out.keys()),
        opset_version=14,
        dynamic_axes={**dynamic_axes_in, **dynamic_axes_out},
    )
    print(f"Model exported to {path}")
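
# A minimal post-export sanity check (a sketch, not part of the original
# export flow): it assumes `onnxruntime` is installed, which this script does
# not otherwise require. The feed/output names mirror `dynamic_axes_in` and
# `dynamic_axes_out` in `export` above; `check_onnx` itself is a hypothetical
# helper name.
def check_onnx(path, shape=(462, 616), with_camera=False):
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    feeds = {"image": np.random.rand(1, 3, *shape).astype(np.float32)}
    if with_camera:
        # Identity intrinsics as a placeholder, shape (1, 3, 3).
        feeds["K"] = np.eye(3, dtype=np.float32)[None]
    # run(None, ...) returns the outputs in declaration order:
    # (out_K, depth[, depth_ssi], confidence).
    for spec, out in zip(session.get_outputs(), session.run(None, feeds)):
        print(f"{spec.name}: {out.shape}")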

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export UniDepthV2 model to ONNX")
    parser.add_argument(
        "--version", type=str, default="v2", choices=["v2"], help="UniDepth version"
    )
    parser.add_argument(
        "--backbone",
        type=str,
        default="vitl14",
        choices=["vits14", "vitl14"],
        help="Backbone model",
    )
    parser.add_argument(
        "--shape",
        type=int,
        nargs=2,
        default=(462, 616),
        help="Input (H, W) shape. Dynamic shapes are not supported!",
    )
    parser.add_argument(
        "--output-path", type=str, default="unidepthv2.onnx", help="Output ONNX file"
    )
    parser.add_argument(
        "--with-camera",
        action="store_true",
        help="Export a model that expects the GT camera matrix at inference",
    )
    args = parser.parse_args()
    version = args.version
    backbone = args.backbone
    shape = args.shape
    output_path = args.output_path
    with_camera = args.with_camera

    # Force each spatial dimension down to a multiple of 14 (the ViT patch size).
    shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape]
    if list(shape) != list(shape_rounded):
        print(f"Shape {shape} is not a multiple of 14. Rounding to {shape_rounded}")
        shape = shape_rounded
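    # e.g. --shape 480 640 is rounded down to (476, 630); the default
    # (462, 616) is already a multiple of 14 and passes through unchanged.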
    # Assumes the command is run from the root of the repo.
    with open(os.path.join("configs", f"config_{version}_{backbone}.json")) as f:
        config = json.load(f)
    # Tell DINO not to use efficient attention: it is not ONNX-exportable.
    config["training"]["export"] = True

    model_factory = UniDepthV2ONNX if not with_camera else UniDepthV2wCamONNX
    model = model_factory(config)
    path = huggingface_hub.hf_hub_download(
        repo_id=f"lpiccinelli/unidepth-{version}-{backbone}",
        filename="pytorch_model.bin",
        repo_type="model",
    )
    info = model.load_state_dict(torch.load(path, map_location="cpu"), strict=False)
    print(f"UniDepth_{version}_{backbone} is loaded with:")
    print(f"\t missing keys: {info.missing_keys}")
    print(f"\t additional keys: {info.unexpected_keys}")
    # Write into $TMPDIR when set, falling back to the current directory.
    export(
        model=model,
        path=os.path.join(os.environ.get("TMPDIR", "."), output_path),
        shape=shape,
        with_camera=with_camera,
    )
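
# Example invocation, assuming this script is saved as e.g. export_onnx.py
# (a hypothetical name) and run from the repo root; TMPDIR controls where the
# ONNX file is written:
#   TMPDIR=/tmp python export_onnx.py --backbone vitl14 --shape 462 616 --with-camera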