Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import torch | |
from torchvision.transforms.functional import to_tensor, normalize | |
from transformers import SegformerForSemanticSegmentation | |
# Optional Hugging Face access token taken from the environment; None when unset.
hf_token = os.getenv("HF_TOKEN")
# All inference in this app runs on this device.
device = torch.device("cpu")

# Segmentation classes listed in class-index order: the i-th name is class id i.
label2id = {
    name: index
    for index, name in enumerate(
        ("background", "skin", "hair", "clothes", "accessories")
    )
}
# Reverse lookup: class id -> class name.
id2label = {index: name for name, index in label2id.items()}

# RGB colour used to paint each class in the rendered segmentation map.
colors = {
    "background": (40, 40, 40),
    "skin": (255, 178, 127),
    "hair": (139, 69, 19),
    "clothes": (100, 149, 237),
    "accessories": (255, 215, 0),
}
# Load the SegFormer checkpoint with this app's five-class label mapping,
# move it to `device`, and switch it to inference mode.
# NOTE(review): with ignore_mismatched_sizes=True, any checkpoint weights whose
# shape does not match this config (e.g. a classifier head that does not have
# len(label2id) outputs) are freshly initialised instead of raising an error.
# Confirm the checkpoint really ships a 5-label head; otherwise predictions
# come from untrained weights.
model = SegformerForSemanticSegmentation.from_pretrained(
    "neuratech-ai/person_segmentation_v3",
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
    token=hf_token,
).to(device)
model.eval()
def preds_to_rgb(preds, *, label_map=None, palette=None):
    """Paint an integer class-id map with one RGB colour per class.

    Args:
        preds: Integer array of class ids. Any shape is accepted; a 2-D
            (height, width) label map is the usual case here.
        label_map: Mapping of class name -> class id. Defaults to the
            module-level ``label2id``.
        palette: Mapping of class name -> (R, G, B) tuple. Defaults to the
            module-level ``colors``. Must contain every name in ``label_map``.

    Returns:
        A ``uint8`` array of shape ``preds.shape + (3,)``. Pixels whose id
        does not appear in ``label_map`` stay black ``(0, 0, 0)``.

    Raises:
        KeyError: if a class name in ``label_map`` is missing from ``palette``.
    """
    if label_map is None:
        label_map = label2id
    if palette is None:
        palette = colors
    preds = np.asarray(preds)
    # Append a trailing channel axis instead of assuming a 2-D input.
    preds_rgb = np.zeros(preds.shape + (3,), dtype=np.uint8)
    for class_name, class_id in label_map.items():
        # Boolean mask selects every pixel of this class; broadcast its colour.
        preds_rgb[preds == class_id] = palette[class_name]
    return preds_rgb
def query_image(img):
    """Segment an input image and return a colour-coded class map.

    The image is converted to RGB and resized so its shorter side is 1024 px
    (aspect ratio kept). It is then normalised with the ImageNet mean/std and
    passed through ``model``. The logits are bilinearly upsampled back to the
    resized image's exact size before taking the per-pixel argmax.

    Args:
        img: Image array as delivered by ``gr.Image()``, or ``None`` when the
            input has been cleared.

    Returns:
        A ``PIL.Image`` with the resized image's dimensions, in which each
        pixel is painted with the ``colors`` entry of its predicted class.
        Returns ``None`` when ``img`` is ``None``.
    """
    if img is None:
        return None
    # Force 3 channels: RGBA or grayscale input would otherwise yield a channel
    # count that the 3-channel mean/std normalisation below rejects.
    image = Image.fromarray(img).convert("RGB")
    width, height = image.size
    scale = 1024 / min(width, height)
    image = image.resize(
        (int(width * scale), int(height * scale)), Image.LANCZOS
    )
    pixel_values = normalize(
        to_tensor(image),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    )
    # to_tensor yields (C, H, W), so the last two dims are (height, width).
    input_size = tuple(pixel_values.shape[-2:])
    with torch.no_grad():
        # Move the batch to the same device as the model.
        outputs = model(pixel_values.unsqueeze(0).to(device))
        # Upsample to the exact input size. Scaling the logit grid by 4 can
        # overshoot by up to 3 px when H or W is not a multiple of 4,
        # stretching the mask relative to the image.
        logits = torch.nn.functional.interpolate(
            outputs.logits, size=input_size, mode="bilinear", align_corners=False
        )
    labels = torch.argmax(logits, dim=1)[0].cpu().numpy()
    return Image.fromarray(preds_to_rgb(labels))
# Gradio UI: one image in, one image out (the colour-coded segmentation map).
demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="neuratech-ai person segmentation v3",
    examples=[["example1.jpg"], ["example2.jpg"], ["example3.jpg"]],
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. to reuse or test query_image) has no network side effect.
if __name__ == "__main__":
    demo.launch()