import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor, normalize
from transformers import SegformerForSemanticSegmentation
|
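# Optional Hugging Face token (for private/gated checkpoints), read from the HF_TOKEN environment variable.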
hf_token = os.environ.get("HF_TOKEN", None)
device = torch.device("cpu")
|
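# Segmentation classes and the RGB color used to draw each class in the output mask.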
label2id = {"background": 0, "skin": 1, "hair": 2, "clothes": 3, "accessories": 4} |
|
id2label = {v: k for k, v in label2id.items()} |
|
colors = { |
|
"background": (40, 40, 40), |
|
"skin": (255, 178, 127), |
|
"hair": (139, 69, 19), |
|
"clothes": (100, 149, 237), |
|
"accessories": (255, 215, 0), |
|
} |
|
|
|
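# Load the SegFormer checkpoint and configure its head for the five classes above;
# ignore_mismatched_sizes=True avoids an error if the saved classifier head has a different shape.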
model = SegformerForSemanticSegmentation.from_pretrained(
    "neuratech-ai/person_segmentation_v3",
    token=hf_token,
    ignore_mismatched_sizes=True,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)
model.eval()
model.to(device)

|
|
def preds_to_rgb(preds):
    """Map an (H, W) array of class ids to an (H, W, 3) RGB visualization."""
    preds_rgb = np.zeros((preds.shape[0], preds.shape[1], 3), dtype=np.uint8)

    for class_name, class_id in label2id.items():
        preds_rgb[preds == class_id] = colors[class_name]

    return preds_rgb

|
|
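# Inference: resize, normalize, run the model, upsample the low-resolution logits, colorize the argmax mask.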
def query_image(img):
    if img is None:
        return None

    # Resize so the longer side is 1024 px, keeping the aspect ratio.
    img = Image.fromarray(img)
    scale = 1024 / max(img.size)
    img = img.resize(
        (int(img.size[0] * scale), int(img.size[1] * scale)), Image.LANCZOS
    )

    # Convert to a tensor and apply ImageNet normalization.
    img = normalize(
        to_tensor(img),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )

    with torch.no_grad():
        outputs = model(img.unsqueeze(0).to(device))

    # SegFormer logits come out at 1/4 of the input resolution; upsample before taking the argmax.
    preds = outputs.logits.cpu()
    h, w = preds.shape[-2:]
    preds = torch.nn.functional.interpolate(
        preds, size=(h * 4, w * 4), mode="bilinear", align_corners=False
    )
    results = torch.argmax(preds, dim=1).numpy()[0]
    results = preds_to_rgb(results)
    return Image.fromarray(results)

|
|
demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="neuratech-ai person segmentation v3",
    examples=[["example1.jpg"], ["example2.jpg"]],
)
|
demo.launch()