Spaces:
Running
Running
File size: 3,391 Bytes
560b597 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import gradio as gr
import torch
import cv2
import numpy as np
import json
from unidepth.models import UniDepthV2
import os
import matplotlib.pyplot as plt
import matplotlib
from PIL import Image
# Load model configurations and initialize model
def load_model(config_path, model_path, encoder, device):
    """Build a UniDepthV2 model from a JSON config and load checkpoint weights.

    Args:
        config_path: Path to the JSON configuration passed to ``UniDepthV2``.
        model_path: Path to a checkpoint readable by ``torch.load`` whose
            ``'model'`` entry holds the model's state dict.
        encoder: Accepted for interface compatibility but not used here; the
            architecture is whatever ``config_path`` describes.
        device: Used both as ``map_location`` for loading and as the device
            the model is moved to.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    model = UniDepthV2(config)
    # SECURITY(review): torch.load unpickles the file, and model_path is fed
    # from a user-editable text box in the demo UI. Only load trusted files;
    # consider torch.load(..., weights_only=True) if the checkpoint holds
    # nothing beyond tensors and plain containers.
    checkpoint = torch.load(model_path, map_location=device)
    # strict=True: missing/unexpected keys raise instead of loading partially.
    model.load_state_dict(checkpoint['model'], strict=True)
    return model.to(device).eval()
# Inference function
def depth_estimation(image, model_path, encoder='vits'):
    """Predict depth for ``image`` and return a colorized depth map with a colorbar.

    Args:
        image: Input image as a NumPy array (from ``gr.Image(type="numpy")``).
            It is permuted to ``(C, H, W)``, so an ``(H, W, C)`` layout is
            assumed.
        model_path: Path to the checkpoint file handed to ``load_model``.
        encoder: Encoder name from the UI dropdown. NOTE(review): it is only
            forwarded to ``load_model``; the config below is always the
            ViT-S/14 one, so other choices currently have no effect here.

    Returns:
        A PIL RGB image: the 'Spectral'-colored depth map with a horizontal
        colorbar (spanning the predicted min/max depth) pasted beneath it.

    Raises:
        gr.Error: If no image is provided, the checkpoint file is missing, or
            any step of inference/rendering fails. Gradio shows the message to
            the user; returning the message as a string (the previous
            behaviour) cannot be rendered by an Image output.
    """
    # Validate inputs up front so users get a specific message rather than a
    # low-level conversion or I/O error.
    if image is None:
        raise gr.Error("Please upload an input image.")
    # isfile also rejects directories, which torch.load cannot read; the
    # emptiness check guards against a blank/None textbox value.
    if not model_path or not os.path.isfile(model_path):
        raise gr.Error("Model checkpoint not found. Please upload a valid model path.")
    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        config_path = 'configs/config_v2_vits14.json'
        # NOTE(review): the checkpoint is reloaded on every request; consider
        # caching if the file does not change between calls.
        model = load_model(config_path, model_path, encoder, device)
        # (H, W, C) array -> (C, H, W) tensor, the layout passed to model.infer.
        rgb = torch.from_numpy(np.array(image)).permute(2, 0, 1).to(device)
        predictions = model.infer(rgb)
        depth = predictions["depth"].squeeze().to('cpu').numpy()
        min_depth = depth.min()
        max_depth = depth.max()
        cmap = matplotlib.colormaps.get_cmap('Spectral')
        depth_color = _colorize_depth(depth, min_depth, max_depth, cmap)
        colorbar_img = _render_colorbar(cmap, min_depth, max_depth)
        height, width = depth_color.shape[:2]
        # White canvas tall enough for the depth map plus the colorbar strip
        # below it. A colorbar wider than the map is clipped at the right edge.
        new_img = Image.new('RGB', (width, height + colorbar_img.size[1]), (255, 255, 255))
        new_img.paste(Image.fromarray(depth_color), (0, 0))
        new_img.paste(colorbar_img, (0, height))
        return new_img
    except Exception as e:
        raise gr.Error(f"Error occurred: {str(e)}") from e


def _colorize_depth(depth, min_depth, max_depth, cmap):
    """Map a depth array to an (H, W, 3) uint8 RGB image using ``cmap``."""
    depth_range = max_depth - min_depth
    if depth_range > 0:
        depth_normalized = (depth - min_depth) / depth_range
    else:
        # Constant depth: avoid 0/0 (NaN) and render one uniform color.
        depth_normalized = np.zeros_like(depth)
    # Colormaps return RGBA floats in [0, 1]; drop alpha and scale to bytes.
    return (cmap(depth_normalized)[:, :, :3] * 255).astype(np.uint8)


def _render_colorbar(cmap, min_depth, max_depth):
    """Render a horizontal colorbar for [min_depth, max_depth] as a PIL image."""
    from io import BytesIO
    fig, ax = plt.subplots(figsize=(6, 0.4))
    try:
        fig.subplots_adjust(bottom=0.5)
        norm = matplotlib.colors.Normalize(vmin=min_depth, vmax=max_depth)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        fig.colorbar(sm, cax=ax, orientation='horizontal', label='Depth (meters)')
        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
    finally:
        # Always release the figure so failed renders don't pile up in
        # pyplot's global figure registry.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Gradio Interface
def main():
    """Assemble the Gradio demo around ``depth_estimation`` and start serving it."""
    # Input components, in the same order as depth_estimation's parameters:
    # (image, model_path, encoder).
    image_input = gr.Image(type="numpy", label="Input Image")
    checkpoint_input = gr.Textbox(value='checkpoint/latest.pth', label='Model Path')
    encoder_input = gr.Dropdown(
        choices=['vits', 'vitb', 'vitl', 'vitg'],
        value='vits',
        label='Encoder',
    )
    depth_output = gr.Image(type="pil", label="Predicted Depth")

    # NOTE(review): the title/description say "Depth Anything V2", while the
    # model built by load_model is UniDepthV2; confirm the intended wording.
    demo = gr.Interface(
        fn=depth_estimation,
        inputs=[image_input, checkpoint_input, encoder_input],
        outputs=[depth_output],
        title="Depth Anything V2 Metric Depth Estimation",
        description="Upload an image to get its estimated depth map using Depth Anything V2.",
    )
    demo.launch()


if __name__ == "__main__":
    # Launch the UI only when executed as a script, not on import.
    main()
|