|
import gradio as gr |
|
from transformers import AutoImageProcessor, BeitForSemanticSegmentation |
|
from PIL import Image |
|
import torch |
|
import numpy as np |
|
|
|
|
|
# Single checkpoint identifier shared by the image processor and the model,
# so the two can never be loaded from different sources by accident.
_CHECKPOINT = "microsoft/beit-base-finetuned-ade-640-640"

# Both objects are created once at import time; predict() reads them as
# module-level globals on every call.
model = BeitForSemanticSegmentation.from_pretrained(_CHECKPOINT)
processor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
|
|
|
def predict(image):
    """Segment an image with the BEiT model and return a color-coded class map.

    Args:
        image: A ``PIL.Image.Image`` to segment, or ``None`` when there is
            nothing to segment (for example, the form was submitted empty).

    Returns:
        An RGB ``PIL.Image.Image`` with the same size as ``image`` in which
        every predicted class is painted with one fixed color, or ``None``
        when ``image`` is ``None``.
    """
    if image is None:
        # Nothing was uploaded: return an empty result instead of letting the
        # image processor fail on a missing input.
        return None

    # The colorized result is RGB, so normalize palette / alpha / grayscale
    # inputs up front.
    rgb_image = image.convert("RGB")
    inputs = processor(images=rgb_image, return_tensors="pt")

    with torch.no_grad():
        logits = model(**inputs).logits

    # Pick the single batch item explicitly; a bare squeeze() would also drop
    # a spatial axis that happens to have length 1.
    seg = logits.argmax(dim=1)[0].cpu().numpy()

    # One fixed color per class id. Seeding the generator keeps the
    # class -> color mapping identical across requests, and the exclusive
    # upper bound of 256 makes the full 0-255 channel range reachable.
    num_classes = logits.shape[1]
    palette = np.random.default_rng(0).integers(
        0, 256, size=(num_classes, 3), dtype=np.uint8
    )
    seg_colored = palette[seg]

    # The class map has the resolution of the logits, not of the upload.
    # Scale it back with nearest-neighbour resampling: it is cheap in memory
    # and never blends two class colors into a meaningless third one.
    return Image.fromarray(seg_colored).resize(rgb_image.size, Image.NEAREST)
|
|
|
|
|
# UI components: the upload is handed to predict() as a PIL image, and the
# PIL image it returns is rendered as the result.
_image_input = gr.Image(type="pil")
_image_output = gr.Image(type="pil")

# Web interface wiring the components to the segmentation function.
gr_interface = gr.Interface(
    title="Image Segmentation with BEiT",
    description="Upload an image to perform segmentation using the microsoft/beit-base-finetuned-ade-640-640 model.",
    fn=predict,
    inputs=_image_input,
    outputs=_image_output,
)


# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    gr_interface.launch()
|
|