# Hugging Face Space page header (badge: "Running on Zero", i.e. ZeroGPU hardware)
import gradio as gr
import torch
from PIL import Image
import cv2
import numpy as np
from distillanydepth.modeling.archs.dam.dam import DepthAnything
from distillanydepth.utils.image_util import colorize_depth_maps
from distillanydepth.midas.transforms import Resize, NormalizeImage, PrepareForNet
from torchvision.transforms import Compose
import os
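# Note: the distillanydepth imports above come from the Distill-Any-Depth
# code base; that package must be vendored in the Space (or installable via
# requirements.txt) for them to resolve.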
# Helper function to load the depth model from a checkpoint
def load_model_by_name(arch_name, checkpoint_path, device):
    if arch_name == 'depthanything':
        if '.safetensors' in checkpoint_path:
            # from_pretrained expects the directory that contains the .safetensors file
            model = DepthAnything.from_pretrained(os.path.dirname(checkpoint_path)).to(device)
        else:
            raise NotImplementedError("Only .safetensors checkpoints are supported.")
    else:
        raise NotImplementedError(f"Unknown architecture: {arch_name}")
    return model.eval()  # inference mode: no dropout / batch-norm updates
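# Optional helper (a sketch, not part of the original app): one way to fetch
# weights from the Hugging Face Hub so that a local checkpoint path exists.
# The repo_id and filename defaults below are hypothetical placeholders, not
# this Space's actual values.
def download_checkpoint(repo_id="some-user/depth-anything-checkpoint",
                        filename="model.safetensors"):
    from huggingface_hub import hf_hub_download
    # hf_hub_download returns the local cache path of the downloaded file
    return hf_hub_download(repo_id=repo_id, filename=filename)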
# Image processing function, adapted to take a PIL image from Gradio
def process_image(image, model, device):
    # Convert the PIL RGB image to BGR (matching this pipeline's original
    # cv2-style channel order) and scale pixel values to [0, 1]
    image_np = np.array(image)[..., ::-1] / 255.0

    transform = Compose([
        Resize(512, 512, resize_target=None, keep_aspect_ratio=False,
               ensure_multiple_of=32, image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        PrepareForNet()
    ])
    image_tensor = transform({'image': image_np})['image']
    image_tensor = torch.from_numpy(image_tensor).unsqueeze(0).to(device)

    with torch.no_grad():  # inference only, so no gradients are needed
        pred_disp, _ = model(image_tensor)

    # (1, 1, H, W) -> (H, W, 1), then min-max normalize to [0, 1]
    pred_disp_np = pred_disp.cpu().detach().numpy()[0, :, :, :].transpose(1, 2, 0)
    pred_disp = (pred_disp_np - pred_disp_np.min()) / (pred_disp_np.max() - pred_disp_np.min())

    # Colorize the relative depth map
    cmap = "Spectral_r"  # default colormap for relative depth
    depth_colored = colorize_depth_maps(pred_disp[None, ...], 0, 1, cmap=cmap).squeeze()
    depth_colored = (depth_colored * 255).astype(np.uint8)
    # colorize_depth_maps returns a CHW array; PIL expects HWC
    depth_image = Image.fromarray(np.transpose(depth_colored, (1, 2, 0)))
    return depth_image
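# Quick local sanity check (a sketch, not part of the original app): run the
# pipeline once on an image from disk. The image path, checkpoint path, and
# output filename are hypothetical placeholders.
def run_local_test(image_path, checkpoint_path):
    device = torch.device("cpu")
    model = load_model_by_name("depthanything", checkpoint_path, device)
    depth = process_image(Image.open(image_path).convert("RGB"), model, device)
    depth.save("depth_preview.png")  # write the colorized map to disk
    return depth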
# Gradio callback. Note that the model is reloaded on every request here;
# loading it once at startup and reusing it would be cheaper.
def gradio_interface(image):
    # Set device to CPU explicitly
    device = torch.device("cpu")  # force CPU inference
    model = load_model_by_name("depthanything", "your_checkpoint_path_here", device)
    # Process the image and return the colorized depth map
    return process_image(image, model, device)
# Create Gradio interface | |
iface = gr.Interface( | |
fn=gradio_interface, | |
inputs=gr.Image(type="pil"), # Only image input, no mode selection | |
outputs=gr.Image(type="pil"), | |
title="Depth Estimation Demo", | |
description="Upload an image to see the depth estimation results." | |
) | |
# Launch the Gradio interface | |
iface.launch() |
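# ZeroGPU variant (a sketch, assuming the Space's "Running on Zero" badge
# means the `spaces` package is available there). On ZeroGPU hardware a GPU is
# attached only while a function decorated with @spaces.GPU runs, so the
# CPU-only callback above would never use it. A hypothetical adaptation, with
# the checkpoint path still a placeholder and the model loaded once at startup:
#
# import spaces
#
# _model = load_model_by_name("depthanything", "your_checkpoint_path_here",
#                             torch.device("cuda"))
#
# @spaces.GPU
# def gradio_interface(image):
#     return process_image(image, _model, torch.device("cuda"))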