import gradio as gr
import torch
from PIL import Image
import numpy as np
from distillanydepth.modeling.archs.dam.dam import DepthAnything
from distillanydepth.utils.image_util import chw2hwc, colorize_depth_maps
from distillanydepth.midas.transforms import Resize, NormalizeImage, PrepareForNet
from torchvision.transforms import Compose
import cv2
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import spaces
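# The `spaces` package provides the @spaces.GPU decorator used below, which
# requests GPU time for the decorated function on Hugging Face ZeroGPU Spaces.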

# Helper function to load model weights from a safetensors checkpoint.
# Note: this helper is kept for reference; the Gradio entry point below
# loads the checkpoint directly instead of calling it.
def load_model_by_name(arch_name, checkpoint_path, device):
    if arch_name == 'depthanything':
        # Load model weights with safetensors
        model_weights = load_file(checkpoint_path)

        # Initialize the model on the target device and apply the loaded weights
        model = DepthAnything(checkpoint_path=None).to(device)
        model.load_state_dict(model_weights)
    else:
        raise NotImplementedError(f"Unknown architecture: {arch_name}")
    return model

# Image processing function
def process_image(image, model, device):
    if model is None:
        return None
    
    # Convert the PIL image (RGB) to a BGR float array in [0, 1]
    image_np = np.array(image)[..., ::-1] / 255
    
    transform = Compose([
        Resize(756, 756, resize_target=False, keep_aspect_ratio=True, ensure_multiple_of=14, resize_method='lower_bound', image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet()
    ])
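    # ensure_multiple_of=14 above matches the 14-pixel patch size of the
    # DINOv2 ViT backbone; the mean/std are the standard ImageNet statistics.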
    
    image_tensor = transform({'image': image_np})['image']
    image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(device)
    
    with torch.no_grad():  # Disable autograd; gradients are not needed for inference
        pred_disp, _ = model(image_tensor)

    # Ensure the depth map is in the correct shape before colorization
    pred_disp_np = pred_disp.cpu().detach().numpy()[0, 0, :, :]  # Remove extra singleton dimensions
    
    # Normalize depth map
    pred_disp = (pred_disp_np - pred_disp_np.min()) / (pred_disp_np.max() - pred_disp_np.min())
    
    # Colorize depth map
    cmap = "Spectral_r"
    depth_colored = colorize_depth_maps(pred_disp[None, ..., None], 0, 1, cmap=cmap).squeeze()  # Ensure correct dimension
    
    # Convert to uint8 for image display
    depth_colored = (depth_colored * 255).astype(np.uint8)
    
    # Convert to HWC format (height, width, channels)
    depth_colored_hwc = chw2hwc(depth_colored)
    
    # Resize to match the original image dimensions (cv2 expects (width, height))
    h, w = image_np.shape[:2]
    depth_colored_hwc = cv2.resize(depth_colored_hwc, (w, h), interpolation=cv2.INTER_LINEAR)
    
    # Convert to a PIL image
    depth_image = Image.fromarray(depth_colored_hwc)
    return depth_image
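# Hypothetical standalone usage of process_image (outside Gradio/Spaces);
# "example.jpg" is a placeholder path:
#   model = load_model_by_name('depthanything', checkpoint_path, device)
#   img = Image.open("example.jpg").convert("RGB")
#   depth = process_image(img, model, device)
#   depth.save("depth.png")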

# Gradio interface function with GPU support
@spaces.GPU
def gradio_interface(image):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_kwargs = dict(
        vitb=dict(
            encoder='vitb',
            features=128,
            out_channels=[96, 192, 384, 768],
        ),
        vitl=dict(
            encoder="vitl", 
            features=256, 
            out_channels=[256, 512, 1024, 1024], 
            use_bn=False, 
            use_clstoken=False, 
            max_depth=150.0, 
            mode='disparity',
            pretrain_type='dinov2',
            del_mask_token=False
        )
    )
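    # Only the 'vitl' configuration is used below; 'vitb' is kept for reference.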
    # Load the ViT-L model and download its checkpoint from the Hugging Face Hub
    model = DepthAnything(**model_kwargs['vitl']).to(device)
    checkpoint_path = hf_hub_download(repo_id="xingyang1/Distill-Any-Depth", filename="large/model.safetensors", repo_type="model")

    # Load model weights with safetensors and apply them to the model
    # (the model is already on the correct device from the .to(device) above)
    model_weights = load_file(checkpoint_path)
    model.load_state_dict(model_weights)
    
    # Process image and return output
    depth_image = process_image(image, model, device)
    return depth_image

# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),  # Only image input, no mode selection
    outputs=gr.Image(type="pil"),  # Only depth image output, no debug info
    title="Depth Estimation Demo",
    description="Upload an image to see the depth estimation result. The model runs on a GPU when one is available and falls back to CPU otherwise."
)

# Launch the Gradio interface
iface.launch()
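# When running locally, a shareable public link can be requested instead with:
#   iface.launch(share=True)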