File size: 2,019 Bytes
b4880ac 3c3d624 f43b2e5 3c3d624 b4880ac f43b2e5 b4880ac f43b2e5 995dcec f43b2e5 0136d74 f43b2e5 cc9455d f43b2e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import os
import gradio as gr
import numpy as np
from PIL import Image
import torch
from torchvision.transforms.functional import to_tensor, normalize
from transformers import SegformerForSemanticSegmentation
# Optional Hugging Face access token taken from the environment; None when unset.
hf_token = os.environ.get("HF_TOKEN")

# All inference in this file runs on the CPU.
device = torch.device("cpu")

# Class name -> integer class id, plus the inverse lookup.
label2id = dict(background=0, skin=1, hair=2, clothes=3, accessories=4)
id2label = dict(zip(label2id.values(), label2id.keys()))

# RGB colour used to paint each class in the rendered mask.
colors = {
    "background": (40, 40, 40),
    "skin": (255, 178, 127),
    "hair": (139, 69, 19),
    "clothes": (100, 149, 237),
    "accessories": (255, 215, 0),
}

# Load the segmentation checkpoint with the label set declared above.
# ignore_mismatched_sizes=True lets loading continue when checkpoint weight
# shapes differ from the model built for this label configuration.
model = SegformerForSemanticSegmentation.from_pretrained(
    "neuratech-ai/person_segmentation_v3",
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
    token=hf_token,
)
# Inference mode (no dropout / batch-norm updates), placed on the target device.
model = model.to(device).eval()
def preds_to_rgb(preds):
    """Render a map of class ids as an RGB colour image array.

    Args:
        preds: Array of class ids; its first two dimensions give the height
            and width of the output.

    Returns:
        A ``uint8`` array of shape ``(preds.shape[0], preds.shape[1], 3)``
        where each position holds the ``colors`` entry of its class. Ids
        that are not listed in ``label2id`` stay black (all zeros).
    """
    canvas = np.zeros((preds.shape[0], preds.shape[1], 3), dtype=np.uint8)
    for name in label2id:
        # Boolean mask of the positions predicted as this class.
        mask = preds == label2id[name]
        canvas[mask] = colors[name]
    return canvas
def query_image(img):
    """Segment an image and return the colour-coded class mask.

    Args:
        img: Image as an array accepted by ``Image.fromarray``, or ``None``
            when no image was supplied.

    Returns:
        A ``PIL.Image`` in which every pixel carries the colour of its
        predicted class, at the resolution the model was run at (longest
        side scaled to 1024 px), or ``None`` if ``img`` is ``None``.
    """
    if img is None:
        return None

    # Force three channels: the normalisation below has exactly three
    # mean/std entries, so grayscale or RGBA input would otherwise fail.
    img = Image.fromarray(img).convert("RGB")

    # Rescale so the longest side becomes 1024 px. Clamp each side to at
    # least 1 so extreme aspect ratios never produce a zero-sized image.
    scale = 1024 / max(img.size)
    new_size = tuple(max(1, int(side * scale)) for side in img.size)
    img = img.resize(new_size, Image.LANCZOS)

    # Add the batch dimension and place the input on the same device as the
    # model so the forward pass keeps working if `device` is changed.
    pixel_values = normalize(
        to_tensor(img),
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(pixel_values).logits
        # The logits come back on a coarser (height, width) grid than the
        # input. Upsample to the exact input size rather than a fixed x4
        # factor, which is off by a pixel or so when a side is not
        # divisible by 4.
        logits = torch.nn.functional.interpolate(
            logits,
            size=pixel_values.shape[-2:],
            mode="bilinear",
            align_corners=False,
        )
        # Most likely class per pixel for the single image in the batch.
        class_ids = logits.argmax(dim=1)[0].cpu().numpy()

    return Image.fromarray(preds_to_rgb(class_ids))
# Web UI: one image input, one image output, backed by query_image.
demo = gr.Interface(
    fn=query_image,
    inputs=[gr.Image()],
    outputs="image",
    examples=[["example1.jpg"], ["example2.jpg"]],
    title="neuratech-ai person segmentation v3",
)

# Start serving the interface.
demo.launch()
|