import os

import gradio as gr
import numpy as np
import torch
from huggingface_hub import login
from transformers import Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation

# Optional Hugging Face access token taken from the environment.
access_token = os.getenv('HF_TOKEN')
if access_token:
    # Only log in when a token is actually set: login(token=None) falls back to
    # an interactive prompt, which would block a headless server.
    login(token=access_token)

# Load the model from Hugging Face
model_name = "gdurkin/cdl_mask2former_hi_res_v3"
processor = Mask2FormerImageProcessor.from_pretrained(model_name, token=access_token)
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_name, token=access_token)
device = torch.device('cpu')
# Keep the model on the same device the inputs are sent to, in inference mode.
model.to(device)
model.eval()


# Define the inference function
def predict(img):
    """Run semantic segmentation and return the per-pixel class-id map.

    Args:
        img: numpy array or torch tensor in channels-last layout, either
            (H, W, C) for one image or (N, H, W, C) for a batch. ``None``
            (e.g. the Gradio input was cleared while ``live=True``) is
            passed through as ``None``.

    Returns:
        A 2-D numpy array (H, W) of predicted class ids for the first image
        in the batch, or ``None`` when ``img`` is ``None``.

    Raises:
        ValueError: if ``img`` is neither a numpy array nor a torch tensor,
            or is not 3-D / 4-D.
    """
    if img is None:
        return None
    if isinstance(img, np.ndarray):
        # ascontiguousarray: torch.from_numpy rejects negative-stride views.
        img = torch.from_numpy(np.ascontiguousarray(img))
    if not torch.is_tensor(img):
        raise ValueError("Unsupported image format")
    input_tensor = img.float()

    if input_tensor.ndim == 3:
        input_tensor = input_tensor.unsqueeze(0)
    elif input_tensor.ndim != 4:
        raise ValueError("Input tensor must be 3D or 4D")

    # Channels-last (N, H, W, C) -> channels-first (N, C, H, W).
    input_tensor = input_tensor.permute(0, 3, 1, 2)

    # NOTE(review): raw pixel values are fed to the model, bypassing the
    # processor's rescale/normalize steps. Confirm this matches how the
    # checkpoint was trained; otherwise use processor(images=..., return_tensors="pt").
    with torch.no_grad():
        outputs = model(input_tensor.to(device))

    # post_process_semantic_segmentation needs one target size per batch item.
    height, width = input_tensor.shape[2], input_tensor.shape[3]
    target_sizes = [(height, width)] * input_tensor.shape[0]
    predicted_segmentation_maps = processor.post_process_semantic_segmentation(
        outputs, target_sizes=target_sizes
    )
    return predicted_segmentation_maps[0].cpu().numpy()


def _label_map_to_image(label_map):
    """Spread class ids over 0-255 gray levels so gr.Image can render them."""
    max_label = max(model.config.num_labels - 1, 1)
    scaled = label_map.astype(np.float32) * (255.0 / max_label)
    return np.clip(scaled, 0, 255).astype(np.uint8)


def _predict_for_display(img):
    """Gradio callback: segment ``img`` and return a displayable grayscale map."""
    label_map = predict(img)
    return None if label_map is None else _label_map_to_image(label_map)


# Create a Gradio interface
iface = gr.Interface(
    fn=_predict_for_display,
    inputs=gr.Image(type="numpy", image_mode='RGB'),
    # "numpy" is not a Gradio component name; show the label map as an image.
    outputs=gr.Image(type="numpy"),
    live=True,
)

# Launch the interface
iface.launch()