import glob
import os

import torch
import torch.nn.functional as F
import numpy as np

from zoedepth.utils.misc import colorize
from zoedepth.utils.config import get_config
from zoedepth.models.builder import build_model
from zoedepth.models.model_io import load_wts

from diffusers import AsymmetricAutoencoderKL, StableDiffusionInpaintPipeline

def load_ckpt(config, model, checkpoint_dir: str = "./checkpoints", ckpt_type: str = "best"):
    """Load model weights from the checkpoint referenced by the config.

    The checkpoint is resolved either from an explicit `config.checkpoint`
    path or by globbing `checkpoint_dir` for `config.ckpt_pattern`.
    """
    if hasattr(config, "checkpoint"):
        checkpoint = config.checkpoint
    elif hasattr(config, "ckpt_pattern"):
        pattern = config.ckpt_pattern
        matches = glob.glob(os.path.join(
            checkpoint_dir, f"*{pattern}*{ckpt_type}*"))
        if not matches:
            raise ValueError(f"No matches found for the pattern {pattern}")

        checkpoint = matches[0]
    else:
        # Nothing to load; return the model with its current weights.
        return model
    model = load_wts(model, checkpoint)
    print(f"Loaded weights from {checkpoint}")
    return model

def get_zoe_dc_model(vanilla: bool = False, ckpt_path: str = None, **kwargs):
    def ZoeD_N(midas_model_type="DPT_BEiT_L_384", vanilla=False, **kwargs):
        if midas_model_type != "DPT_BEiT_L_384":
            raise ValueError(f"Only DPT_BEiT_L_384 MiDaS model is supported for pretrained Zoe_N model, got: {midas_model_type}")

        zoedepth_config = get_config("zoedepth", "train", **kwargs)
        model = build_model(zoedepth_config)

        if vanilla:
            # Vanilla ZoeDepth: plain RGB input, no depth-completion channels.
            model.vanilla = True
            return model
        model.vanilla = False

        if zoedepth_config.add_depth_channel:
            # Widen the MiDaS patch embedding by two input channels so the
            # backbone can take the sparse depth map and its validity mask
            # alongside the RGB image.
            proj = model.core.core.pretrained.model.patch_embed.proj
            model.core.core.pretrained.model.patch_embed.proj = torch.nn.Conv2d(
                proj.in_channels + 2,
                proj.out_channels,
                kernel_size=proj.kernel_size,
                stride=proj.stride,
                padding=proj.padding,
                bias=True)

        if ckpt_path is not None:
            assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
            zoedepth_config.checkpoint = ckpt_path
        else:
            raise ValueError("ckpt_path must be provided for the non-vanilla model")

        model = load_ckpt(zoedepth_config, model)

        return model

    return ZoeD_N(vanilla=vanilla, ckpt_path=ckpt_path, **kwargs)

def infer_with_pad(zoe, x, pad_input: bool = True, fh: float = 3, fw: float = 3, upsampling_mode: str = "bicubic", padding_mode: str = "reflect", **kwargs):
    assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim())

    if pad_input:
        assert fh > 0 or fw > 0, "at least one of fh and fw must be greater than 0"
        # Pad proportionally to the image size to reduce boundary artifacts.
        pad_h = int(np.sqrt(x.shape[2] / 2) * fh)
        pad_w = int(np.sqrt(x.shape[3] / 2) * fw)
        padding = [pad_w, pad_w]
        if pad_h > 0:
            padding += [pad_h, pad_h]

        # Reflect-pad the RGB channels; zero-pad the sparse depth and mask channels.
        x_rgb = x[:, :3]
        x_remaining = x[:, 3:]
        x_rgb = F.pad(x_rgb, padding, mode=padding_mode, **kwargs)
        x_remaining = F.pad(x_remaining, padding, mode="constant", value=0, **kwargs)
        x = torch.cat([x_rgb, x_remaining], dim=1)
    out = zoe(x)["metric_depth"]
    if out.shape[-2:] != x.shape[-2:]:
        out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False)
    if pad_input:
        # Crop back to the original size, handling the case where pad_h or pad_w is 0.
        if pad_h > 0:
            out = out[:, :, pad_h:-pad_h, :]
        if pad_w > 0:
            out = out[:, :, :, pad_w:-pad_w]
    return out

@torch.no_grad()
def infer_with_zoe_dc(zoe_dc, image, sparse_depth, scaling: float = 1):
    # Validity mask: pixels with a known (positive) sparse depth value.
    sparse_depth_mask = (sparse_depth[None, None, ...] > 0).float()
    # The metric depth range defined during training is [1e-3, 10], so the
    # sparse depth is normalized by `scaling * 10` before being concatenated
    # with the RGB image and the validity mask.
    x = torch.cat([image[None, ...], sparse_depth[None, None, ...] / (float(scaling) * 10.0), sparse_depth_mask], dim=1).to(zoe_dc.device)

    # Average the prediction with its horizontally flipped counterpart.
    out = infer_with_pad(zoe_dc, x)
    out_flip = infer_with_pad(zoe_dc, torch.flip(x, dims=[3]))
    out = (out + torch.flip(out_flip, dims=[3])) / 2

    # Rescale the prediction back to the scale of the input sparse depth.
    pred_depth = float(scaling) * out

    return torch.nn.functional.interpolate(pred_depth, image.shape[-2:], mode='bilinear', align_corners=True)[0, 0]

def get_sd_pipeline():
    # Stable Diffusion 2 inpainting pipeline in half precision.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=torch.float16,
    )
    # Swap in the asymmetric autoencoder VAE, which is tailored to inpainting
    # and better preserves the unmasked image content.
    pipe.vae = AsymmetricAutoencoderKL.from_pretrained(
        "cross-attention/asymmetric-autoencoder-kl-x-2",
        torch_dtype=torch.float16
    )

    return pipe
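

# Example usage: a minimal sketch of how the helpers above fit together, not
# part of the original module. It assumes a fine-tuned depth-completion
# checkpoint (the path below is a placeholder), and it substitutes random
# tensors for a real RGB image in [0, 1] and a metric sparse depth map with
# zeros at unknown pixels.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Hypothetical checkpoint path; replace with an actual depth-completion checkpoint.
    zoe_dc = get_zoe_dc_model(ckpt_path="./checkpoints/zoe_dc.pt").to(device).eval()

    image = torch.rand(3, 384, 512, device=device)        # RGB in [0, 1], (3, H, W)
    sparse_depth = torch.zeros(384, 512, device=device)   # metric depth, 0 = unknown

    depth = infer_with_zoe_dc(zoe_dc, image, sparse_depth)  # dense depth, (H, W)
    print(depth.shape, depth.min().item(), depth.max().item())

    # The inpainting pipeline is loaded separately and moved to the same device.
    pipe = get_sd_pipeline().to(device)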