import json
import os
from io import BytesIO

import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from unidepth.models import UniDepthV2
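# Optional: on a headless server, a non-interactive backend may be needed for
# the colorbar figure to render (uncomment if applicable)
# matplotlib.use('Agg')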


# Load the model configuration and initialize the model
def load_model(config_path, model_path, device):
    with open(config_path) as f:
        config = json.load(f)

    model = UniDepthV2(config)
    # Load weights; the checkpoint is expected to store them under a 'model' key
    model.load_state_dict(torch.load(model_path, map_location=device)['model'], strict=True)
    model = model.to(device).eval()

    return model
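
# A checkpoint matching the layout assumed above could be produced like this
# (hypothetical sketch; 'checkpoint/latest.pth' mirrors the UI default):
#   torch.save({'model': model.state_dict()}, 'checkpoint/latest.pth')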

# Inference function
def depth_estimation(image, model_path, encoder='vits'):
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Pick the config matching the selected encoder; this assumes a
        # configs/config_v2_<encoder>14.json file exists for each choice
        config_path = f'configs/config_v2_{encoder}14.json'

        # Fail early if the checkpoint is missing
        if not os.path.exists(model_path):
            raise gr.Error("Model checkpoint not found. Please provide a valid checkpoint path.")

        model = load_model(config_path, model_path, device)

        # Preprocess image: HWC numpy array -> CHW tensor, as expected by UniDepthV2.infer
        rgb = torch.from_numpy(np.array(image)).permute(2, 0, 1).to(device)
        predictions = model.infer(rgb)
        depth = predictions["depth"].squeeze().cpu().numpy()
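        # Note: UniDepthV2.infer typically also returns 'points' (backprojected
        # 3D coordinates) and 'intrinsics' (the estimated camera matrix),
        # which are unused here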

        # Normalize depth to [0, 1] for visualization; the epsilon guards
        # against division by zero on a constant depth map
        min_depth = depth.min()
        max_depth = depth.max()

        depth_normalized = (depth - min_depth) / max(max_depth - min_depth, 1e-8)

        # Apply colormap
        cmap = matplotlib.colormaps.get_cmap('Spectral')
        depth_color = (cmap(depth_normalized)[:, :, :3] * 255).astype(np.uint8)

        # Create a figure and axis for the colorbar
        fig, ax = plt.subplots(figsize=(6, 0.4))
        fig.subplots_adjust(bottom=0.5)

        # Create a colorbar
        norm = matplotlib.colors.Normalize(vmin=min_depth, vmax=max_depth)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, cax=ax, orientation='horizontal', label='Depth (meters)')

        # Save the colorbar to a BytesIO object
        from io import BytesIO
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
        plt.close(fig)
        buf.seek(0)

        # Open the colorbar image
        colorbar_img = Image.open(buf)

        # Create a new image with space for the colorbar
        new_height = depth_color.shape[0] + colorbar_img.size[1]
        new_img = Image.new('RGB', (depth_color.shape[1], new_height), (255, 255, 255))

        # Paste the depth image and colorbar
        new_img.paste(Image.fromarray(depth_color), (0, 0))
        new_img.paste(colorbar_img, (0, depth_color.shape[0]))

        return new_img

    except gr.Error:
        # Re-raise UI-facing errors (e.g. the missing-checkpoint message) as-is
        raise
    except Exception as e:
        # Surface unexpected failures in the Gradio UI rather than returning a
        # string to an Image output, which would fail to render
        raise gr.Error(f"Error occurred: {e}")

# Gradio Interface
def main():
    iface = gr.Interface(
        fn=depth_estimation,
        inputs=[
            gr.Image(type="numpy", label="Input Image"),
            gr.Textbox(value='checkpoint/latest.pth', label='Model Path'),
            # Encoder choices mirror the config naming scheme; not every choice
            # may have a published UniDepth V2 checkpoint
            gr.Dropdown(choices=['vits', 'vitb', 'vitl', 'vitg'], value='vits', label='Encoder'),
        ],
        outputs=[
            gr.Image(type="pil", label="Predicted Depth")
        ],
        title="Depth Anything V2 Metric Depth Estimation",
        description="Upload an image to get its estimated depth map using Depth Anything V2.",
    )

    iface.launch()


if __name__ == "__main__":
    main()
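
# Minimal smoke test without the UI (hypothetical paths; reuses the numpy/PIL
# imports above):
#   img = np.array(Image.open('example.jpg').convert('RGB'))
#   out = depth_estimation(img, 'checkpoint/latest.pth', encoder='vits')
#   out.save('depth_vis.png')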