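# Page-scrape residue from the hosting site (not part of the program),
# preserved as comments so the file parses as Python:
#   englissi's picture
#   Update app.py
#   7c4faa5 verified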
import gradio as gr
from transformers import AutoImageProcessor, BeitForSemanticSegmentation
from PIL import Image
import torch
import numpy as np
# Load the image processor and the segmentation model once, at import time.
# Both are created from the same checkpoint identifier.
_CHECKPOINT = "microsoft/beit-base-finetuned-ade-640-640"
processor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
model = BeitForSemanticSegmentation.from_pretrained(_CHECKPOINT)
def _class_palette(num_classes, seed=0):
    """Return a fixed ``(num_classes, 3)`` uint8 RGB lookup table.

    The generator is seeded, so a given class id maps to the same colour on
    every call instead of being redrawn per request.
    """
    rng = np.random.default_rng(seed)
    # ``high`` is exclusive, so 256 is needed to cover the full 0..255 range.
    return rng.integers(0, 256, size=(num_classes, 3), dtype=np.uint8)


def predict(image: "Image.Image | None") -> "Image.Image | None":
    """Segment ``image`` and return a colour-coded class map.

    Args:
        image: The PIL image to segment, or ``None`` when nothing was
            supplied (e.g. the form was submitted without an upload).

    Returns:
        An RGB PIL image with the same size as ``image`` in which every pixel
        is painted with the colour assigned to its predicted class, or
        ``None`` when ``image`` is ``None``.
    """
    # Nothing to segment: return an empty result instead of failing inside
    # the processor.
    if image is None:
        return None

    # Preprocess the image into model-ready tensors.
    inputs = processor(images=image, return_tensors="pt")

    # Inference only; no gradients are needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Class scores live on dim 1. Take the best class per pixel and select the
    # single batch entry explicitly: a bare squeeze() would also drop a
    # spatial dimension that happens to have size 1.
    class_map = logits.argmax(dim=1)[0].cpu().numpy()

    # One vectorised palette lookup replaces the per-class masking loop.
    palette = _class_palette(logits.shape[1])
    colored = palette[class_map]

    mask = Image.fromarray(colored)
    # The class map has the model's output resolution, not the input's.
    # Resize back so the mask lines up with the uploaded picture; nearest
    # neighbour keeps colours exact (no blended colours at class borders).
    if mask.size != image.size:
        mask = mask.resize(image.size, resample=Image.NEAREST)
    return mask
# Gradio interface: a PIL image goes into `predict`, a PIL image comes back.
gr_interface = gr.Interface(
    title="Image Segmentation with BEiT",
    description=(
        "Upload an image to perform segmentation using the "
        "microsoft/beit-base-finetuned-ade-640-640 model."
    ),
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)

# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    gr_interface.launch()