"""Gradio demo: metric depth estimation with UniDepthV2 and a colorbar legend."""

import functools
import json
import os
from io import BytesIO

import gradio as gr
import torch
import cv2
import numpy as np
from unidepth.models import UniDepthV2
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.cm
import matplotlib.colors
from matplotlib.figure import Figure
from PIL import Image

# Config used by the original script; also the fallback when no
# encoder-specific config file is present on disk.
DEFAULT_CONFIG_PATH = 'configs/config_v2_vits14.json'

# Encoder names offered in the UI dropdown.
ENCODER_CHOICES = ('vits', 'vitb', 'vitl', 'vitg')


# Load model configurations and initialize model
def load_model(config_path, model_path, encoder, device):
    """Build a UniDepthV2 model from a JSON config and load its weights.

    Args:
        config_path: Path to the JSON model configuration.
        model_path: Path to a checkpoint whose ``'model'`` entry holds the
            state dict.
        encoder: Encoder name. Accepted for interface compatibility; it is
            not used here because the architecture comes from the config.
        device: Device used for ``map_location`` and final placement.

    Returns:
        The model moved to ``device`` and switched to eval mode.
    """
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)
    model = UniDepthV2(config)
    # NOTE(security): torch.load unpickles the file, and model_path comes
    # from a user-editable textbox. Only point this at trusted checkpoints
    # (consider weights_only=True if the installed torch supports it).
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model'], strict=True)
    return model.to(device).eval()


@functools.lru_cache(maxsize=1)
def _load_model_cached(config_path, model_path, encoder, device,
                       checkpoint_mtime):
    """Return a cached model so repeated requests do not reload from disk.

    ``checkpoint_mtime`` is only part of the cache key: replacing the
    checkpoint file at the same path invalidates the cached model.
    """
    return load_model(config_path, model_path, encoder, device)


def _resolve_config_path(encoder):
    """Pick the config file for ``encoder``, falling back to the default.

    Assumes sibling configs follow the same naming scheme as the default
    one (TODO confirm against the configs directory); when the derived file
    does not exist the original hard-coded config is used, so behavior for
    existing setups is unchanged.
    """
    if encoder in ENCODER_CHOICES:
        candidate = f'configs/config_v2_{encoder}14.json'
        if os.path.exists(candidate):
            return candidate
    return DEFAULT_CONFIG_PATH


def _normalize_depth(depth):
    """Scale ``depth`` to [0, 1]; return (normalized, min_depth, max_depth)."""
    min_depth = float(depth.min())
    max_depth = float(depth.max())
    span = max_depth - min_depth
    if np.isfinite(span) and span > 0:
        normalized = (depth - min_depth) / span
    else:
        # Constant (or degenerate) depth map: avoid 0/0 producing NaNs.
        normalized = np.zeros_like(depth)
    return normalized, min_depth, max_depth


def _render_colorbar(cmap, min_depth, max_depth):
    """Render a horizontal colorbar for the depth range as an RGB PIL image."""
    # A standalone Figure avoids pyplot's global figure registry, which is
    # not meant to be used from server worker threads and would leak the
    # figure if an exception occurred before an explicit close.
    fig = Figure(figsize=(6, 0.4))
    ax = fig.subplots()
    fig.subplots_adjust(bottom=0.5)

    norm = matplotlib.colors.Normalize(vmin=min_depth, vmax=max_depth)
    mappable = matplotlib.cm.ScalarMappable(cmap=cmap, norm=norm)
    mappable.set_array([])
    fig.colorbar(mappable, cax=ax, orientation='horizontal',
                 label='Depth (meters)')

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0.1)
    buf.seek(0)
    with Image.open(buf) as colorbar_img:
        # convert() forces the pixel data to load before the buffer goes away.
        return colorbar_img.convert('RGB')


def _stack_with_colorbar(depth_color, colorbar_img):
    """Place the colorbar under the colored depth map on a white canvas."""
    depth_img = Image.fromarray(depth_color)
    width, height = depth_img.size
    bar_width, bar_height = colorbar_img.size

    if bar_width > width:
        # Shrink instead of letting paste() silently crop the tick labels.
        scaled_height = max(1, round(bar_height * width / bar_width))
        colorbar_img = colorbar_img.resize((width, scaled_height),
                                           Image.LANCZOS)
        bar_width, bar_height = colorbar_img.size

    canvas = Image.new('RGB', (width, height + bar_height), (255, 255, 255))
    canvas.paste(depth_img, (0, 0))
    canvas.paste(colorbar_img, ((width - bar_width) // 2, height))
    return canvas


# Inference function
def depth_estimation(image, model_path, encoder='vits'):
    """Estimate depth for ``image`` and return a colorized map with legend.

    Args:
        image: Input image as an array-like of shape (H, W) or (H, W, C);
            an alpha channel, if present, is dropped.
        model_path: Path to the model checkpoint on disk.
        encoder: Encoder name selected in the UI; used to pick the config.

    Returns:
        A PIL RGB image: the depth map colored with the 'Spectral' colormap,
        with a horizontal colorbar showing the depth range underneath.

    Raises:
        gr.Error: If the image or checkpoint is missing, or if inference
            fails. Raising (rather than returning a message string) lets
            Gradio show the message instead of treating the string as an
            image file path.
    """
    if image is None:
        raise gr.Error("No input image provided. Please upload an image.")
    # Ensure model path exists
    if not model_path or not os.path.exists(model_path):
        raise gr.Error(
            "Model checkpoint not found. Please upload a valid model path."
        )

    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        config_path = _resolve_config_path(encoder)
        model = _load_model_cached(
            config_path, model_path, encoder, device,
            os.path.getmtime(model_path),
        )

        # Preprocess image into a C, H, W tensor with exactly three channels.
        pixels = np.array(image)
        if pixels.ndim == 2:
            pixels = np.stack([pixels] * 3, axis=-1)
        pixels = np.ascontiguousarray(pixels[..., :3])
        rgb = torch.from_numpy(pixels).permute(2, 0, 1).to(device)

        with torch.no_grad():
            predictions = model.infer(rgb)
        depth = predictions["depth"].squeeze().to('cpu').numpy()

        depth_normalized, min_depth, max_depth = _normalize_depth(depth)

        # Apply colormap (drop the alpha channel of the RGBA output).
        cmap = matplotlib.colormaps.get_cmap('Spectral')
        depth_color = (cmap(depth_normalized)[:, :, :3] * 255).astype(np.uint8)

        colorbar_img = _render_colorbar(cmap, min_depth, max_depth)
        return _stack_with_colorbar(depth_color, colorbar_img)
    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"Error occurred: {str(e)}") from e


# Gradio Interface
def main():
    """Build the Gradio interface and launch the demo server."""
    iface = gr.Interface(
        fn=depth_estimation,
        inputs=[
            gr.Image(type="numpy", label="Input Image"),
            gr.Textbox(value='checkpoint/latest.pth', label='Model Path'),
            gr.Dropdown(choices=list(ENCODER_CHOICES), value='vits',
                        label='Encoder'),
        ],
        outputs=[
            gr.Image(type="pil", label="Predicted Depth"),
        ],
        title="Metric Depth Estimation",
        description=(
            "Upload an image to get its estimated depth map using "
            "UniDepth V2."
        ),
    )
    iface.launch()


if __name__ == "__main__":
    main()