# sna89: update app.py file (dbf9317)
from transformers import (
MaskFormerImageProcessor,
AutoImageProcessor,
MaskFormerForInstanceSegmentation,
)
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import gradio as gr
import numpy as np
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Image pre/post-processing configuration comes from the base MaskFormer
# checkpoint; the model weights come from the "sna89/segmentation_model"
# checkpoint.
processor = AutoImageProcessor.from_pretrained("facebook/maskformer-swin-base-coco")

# Load the segmentation model and move it onto the selected device in one step.
model = MaskFormerForInstanceSegmentation.from_pretrained(
    "sna89/segmentation_model"
).to(device)
def segment_image(img):
    """Segment an image and return a rendered picture of the predicted map.

    The image is preprocessed with the module-level ``processor``, run through
    the module-level ``model`` without gradient tracking, post-processed into a
    semantic map, and drawn with matplotlib on a 5x5-inch figure without axes.

    Args:
        img: Input image exposing ``.size`` as ``(width, height)`` (the Gradio
            input is declared with ``type="pil"``). May be ``None`` when the
            callback fires without an image.

    Returns:
        A NumPy array holding the RGBA pixels of the rendered figure, or
        ``None`` when ``img`` is ``None``.
    """
    # Guard against a missing image (e.g. the input was cleared while the
    # interface is live) instead of failing later on ``img.size``.
    if img is None:
        return None

    model_inputs = processor(img, return_tensors="pt").to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**model_inputs)

    # ``img.size`` is (width, height); it is reversed here so the map is
    # resized back to the original image as (height, width).
    semantic_map = processor.post_process_semantic_segmentation(
        outputs, target_sizes=[img.size[::-1]]
    )[0]

    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        # Draw on this figure's own axes rather than pyplot's implicit
        # "current axes", which is shared global state.
        ax.axis("off")
        ax.imshow(semantic_map.to("cpu"))
        fig.canvas.draw()  # Render the figure so the pixel buffer is populated.
        # np.array copies the buffer, so the result stays valid after the
        # figure is closed below.
        return np.array(fig.canvas.buffer_rgba())
    finally:
        # Every call creates a new figure; close it so figures do not
        # accumulate across requests.
        plt.close(fig)
# Example images offered in the UI; each inner list is one set of inputs.
example_inputs = [["image.jpg"], ["image (1).jpg"]]

# Image-in / image-out interface around segment_image. live=True re-runs the
# function when the input changes instead of waiting for a submit click.
demo = gr.Interface(
    segment_image,
    gr.Image(type="pil"),
    gr.Image(type="pil"),
    examples=example_inputs,
    title="Semantic segmentation for sidewalk dataset",
    live=True,
)

# Start the app and request a public share link.
demo.launch(share=True)