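"""Utilities for depth completion and inpainting.

Builds a ZoeDepth-based depth-completion model (optionally extended with two extra
input channels for a sparse depth map and its validity mask), runs padded,
flip-averaged metric-depth inference, and constructs a Stable Diffusion 2
inpainting pipeline with an asymmetric VAE.
"""
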
import glob
import os
import torch
import torch.nn.functional as F
import numpy as np
from zoedepth.utils.misc import colorize
from zoedepth.utils.config import get_config
from zoedepth.models.builder import build_model
from zoedepth.models.model_io import load_wts
from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline

def load_ckpt(config, model, checkpoint_dir: str = "./checkpoints", ckpt_type: str = "best"):
    """Load weights from an explicit checkpoint path or by matching a filename pattern in checkpoint_dir."""
    if hasattr(config, "checkpoint"):
        checkpoint = config.checkpoint
    elif hasattr(config, "ckpt_pattern"):
        pattern = config.ckpt_pattern
        matches = glob.glob(os.path.join(
            checkpoint_dir, f"*{pattern}*{ckpt_type}*"))
        if not matches:
            raise ValueError(f"No matches found for the pattern {pattern}")
        checkpoint = matches[0]
    else:
        # No checkpoint specified; return the model unchanged.
        return model
    model = load_wts(model, checkpoint)
    print(f"Loaded weights from {checkpoint}")
    return model

def get_zoe_dc_model(vanilla: bool = False, ckpt_path: str = None, **kwargs):
    def ZoeD_N(midas_model_type="DPT_BEiT_L_384", vanilla=False, **kwargs):
        if midas_model_type != "DPT_BEiT_L_384":
            raise ValueError(f"Only DPT_BEiT_L_384 MiDaS model is supported for pretrained ZoeD_N model, got: {midas_model_type}")

        zoedepth_config = get_config("zoedepth", "train", **kwargs)
        model = build_model(zoedepth_config)

        if vanilla:
            # Vanilla ZoeDepth: keep the original RGB-only input and skip checkpoint loading.
            model.vanilla = True
            return model
        else:
            model.vanilla = False

        if zoedepth_config.add_depth_channel and not vanilla:
            # Widen the patch embedding to accept two extra input channels:
            # the sparse depth map and its validity mask.
            proj = model.core.core.pretrained.model.patch_embed.proj
            model.core.core.pretrained.model.patch_embed.proj = torch.nn.Conv2d(
                proj.in_channels + 2,
                proj.out_channels,
                kernel_size=proj.kernel_size,
                stride=proj.stride,
                padding=proj.padding,
                bias=True)

        if ckpt_path is not None:
            assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
            zoedepth_config.checkpoint = ckpt_path
        else:
            assert vanilla, "ckpt_path must be provided for the non-vanilla (depth-completion) model"

        model = load_ckpt(zoedepth_config, model)
        return model

    return ZoeD_N(vanilla=vanilla, ckpt_path=ckpt_path, **kwargs)

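# Example usage (a minimal sketch; the checkpoint path below is hypothetical and must
# point to an actual depth-completion checkpoint on disk):
#
#   zoe_dc = get_zoe_dc_model(ckpt_path="./checkpoints/zoe_dc_best.pt").to("cuda")
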
def infer_with_pad(zoe, x, pad_input: bool = True, fh: float = 3, fw: float = 3, upsampling_mode: str = "bicubic", padding_mode: str = "reflect", **kwargs):
    assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim())

    if pad_input:
        assert fh > 0 or fw > 0, "at least one of fh and fw must be greater than 0"
        pad_h = int(np.sqrt(x.shape[2] / 2) * fh)
        pad_w = int(np.sqrt(x.shape[3] / 2) * fw)
        padding = [pad_w, pad_w]
        if pad_h > 0:
            padding += [pad_h, pad_h]

        # Reflect-pad only the RGB channels; the sparse depth and mask channels are
        # zero-padded so that no spurious depth values are introduced at the border.
        x_rgb = x[:, :3]
        x_remaining = x[:, 3:]
        x_rgb = F.pad(x_rgb, padding, mode=padding_mode, **kwargs)
        x_remaining = F.pad(x_remaining, padding, mode="constant", value=0, **kwargs)
        x = torch.cat([x_rgb, x_remaining], dim=1)
    out = zoe(x)["metric_depth"]
    if out.shape[-2:] != x.shape[-2:]:
        out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False)
    if pad_input:
        # Crop back to the original size, handling the case where pad_h or pad_w is 0.
        if pad_h > 0:
            out = out[:, :, pad_h:-pad_h, :]
        if pad_w > 0:
            out = out[:, :, :, pad_w:-pad_w]
    return out

@torch.no_grad()
def infer_with_zoe_dc(zoe_dc, image, sparse_depth, scaling: float = 1):
    # Binary mask marking the pixels where a sparse depth value is available.
    sparse_depth_mask = (sparse_depth[None, None, ...] > 0).float()
    # The metric depth range defined during training is [1e-3, 10], so the sparse
    # depth is normalized into that range before being concatenated to the image.
    x = torch.cat([image[None, ...], sparse_depth[None, None, ...] / (float(scaling) * 10.0), sparse_depth_mask], dim=1).to(zoe_dc.device)

    out = infer_with_pad(zoe_dc, x)
    # Average the predictions for the input and its horizontal flip.
    out_flip = infer_with_pad(zoe_dc, torch.flip(x, dims=[3]))
    out = (out + torch.flip(out_flip, dims=[3])) / 2

    pred_depth = float(scaling) * out
    return torch.nn.functional.interpolate(pred_depth, image.shape[-2:], mode='bilinear', align_corners=True)[0, 0]

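# Example usage (a minimal sketch with dummy inputs; `image` is a 3xHxW RGB tensor and
# `sparse_depth` an HxW tensor that is zero wherever no depth measurement exists):
#
#   image = torch.rand(3, 512, 512, device="cuda")
#   sparse_depth = torch.zeros(512, 512, device="cuda")
#   depth = infer_with_zoe_dc(zoe_dc, image, sparse_depth)  # HxW metric depth map
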
def get_sd_pipeline():
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
    )
    # Swap in the asymmetric VAE, which better preserves unmasked regions during inpainting.
    pipe.vae = AsymmetricAutoencoderKL.from_pretrained(
        "cross-attention/asymmetric-autoencoder-kl-x-2",
        torch_dtype=torch.float16
    )
    return pipe

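# Example usage (a minimal sketch; `init_image` and `mask` are hypothetical PIL images,
# with white mask pixels marking the region to inpaint):
#
#   pipe = get_sd_pipeline().to("cuda")
#   result = pipe(prompt="a cozy living room", image=init_image, mask_image=mask).images[0]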