from transformers import (
    AutoImageProcessor,
    MaskFormerForInstanceSegmentation,
)
import torch
import matplotlib.pyplot as plt
import gradio as gr
import numpy as np

# Image processor for resizing/normalization; the checkpoint is a MaskFormer
# model fine-tuned for sidewalk semantic segmentation.
processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-coco")
model = MaskFormerForInstanceSegmentation.from_pretrained(
    "sna89/segmentation_model"
)

# Run inference on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


def segment_image(img):
    """Run semantic segmentation on a PIL image and return the rendered mask."""
    # Preprocess the image and move the tensors to the model's device.
    img_pt = processor(img, return_tensors="pt")
    img_pt = img_pt.to(device)
    with torch.no_grad():
        outputs = model(**img_pt)

    # Post-process to a per-pixel class map at the original resolution;
    # PIL's img.size is (width, height), so reverse it to (height, width).
    predicted_semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[img.size[::-1]]
    )[0]

    # Render the class map with Matplotlib and return it as an RGBA array.
    fig, ax = plt.subplots(figsize=(5, 5))
    plt.axis('off')
    plt.imshow(predicted_semantic_map.to("cpu"))
    fig.canvas.draw()
    image_array = np.array(fig.canvas.renderer.buffer_rgba())
    plt.close(fig)  # free the figure so repeated calls don't leak memory
    return image_array
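
# Optional local check (a minimal sketch; "sidewalk.jpg" is a hypothetical
# example file, not shipped with this script):
#
#     from PIL import Image
#     mask = segment_image(Image.open("sidewalk.jpg"))
#     print(mask.shape)  # RGBA array of shape (H, W, 4) from the Matplotlib canvas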

# Gradio UI: takes a PIL image as input; the NumPy RGBA array returned by
# segment_image is displayed by the Image output component.
demo = gr.Interface(
    fn=segment_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Semantic segmentation for sidewalk dataset",
    examples=[["image.jpg"], ["image (1).jpg"]],
    live=True,
)

demo.launch(share=True)