Spaces: Runtime error
Upload 4 files
app.py
ADDED
@@ -0,0 +1,224 @@
import colorsys
import os

import gradio as gr
import matplotlib.colors as mcolors
import numpy as np
import torch
from gradio.themes.utils import sizes
from matplotlib import pyplot as plt
from matplotlib.patches import Patch
from PIL import Image
from torchvision import transforms

# ----------------- HELPER FUNCTIONS ----------------- #

ASSETS_DIR = os.path.join(os.path.dirname(__file__), "assets")

# Goliath body-part label set used by the Sapiens segmentation checkpoints.
LABELS_TO_IDS = {
    "Background": 0,
    "Apparel": 1,
    "Face Neck": 2,
    "Hair": 3,
    "Left Foot": 4,
    "Left Hand": 5,
    "Left Lower Arm": 6,
    "Left Lower Leg": 7,
    "Left Shoe": 8,
    "Left Sock": 9,
    "Left Upper Arm": 10,
    "Left Upper Leg": 11,
    "Lower Clothing": 12,
    "Right Foot": 13,
    "Right Hand": 14,
    "Right Lower Arm": 15,
    "Right Lower Leg": 16,
    "Right Shoe": 17,
    "Right Sock": 18,
    "Right Upper Arm": 19,
    "Right Upper Leg": 20,
    "Torso": 21,
    "Upper Clothing": 22,
    "Lower Lip": 23,
    "Upper Lip": 24,
    "Lower Teeth": 25,
    "Upper Teeth": 26,
    "Tongue": 27,
}


def get_palette(num_cls):
    # Flat 256-entry RGB palette: black background, evenly spaced hues for the
    # remaining classes, alternating brightness so neighboring ids stay distinct.
    palette = [0] * (256 * 3)
    palette[0:3] = [0, 0, 0]

    for j in range(1, num_cls):
        hue = (j - 1) / (num_cls - 1)
        saturation = 1.0
        value = 1.0 if j % 2 == 0 else 0.5
        rgb = colorsys.hsv_to_rgb(hue, saturation, value)
        r, g, b = [int(x * 255) for x in rgb]
        palette[j * 3 : j * 3 + 3] = [r, g, b]

    return palette


def create_colormap(palette):
    colormap = np.array(palette).reshape(-1, 3) / 255.0
    return mcolors.ListedColormap(colormap)


def visualize_mask_with_overlay(img: Image.Image, mask: Image.Image, labels_to_ids: dict[str, int], alpha=0.5):
    # Alpha-blend a color-coded class overlay on top of the input image.
    img_np = np.array(img.convert("RGB"))
    mask_np = np.array(mask)

    num_cls = len(labels_to_ids)
    palette = get_palette(num_cls)
    colormap = create_colormap(palette)

    overlay = np.zeros((*mask_np.shape, 3), dtype=np.uint8)
    for label, idx in labels_to_ids.items():
        if idx != 0:  # leave the background unpainted
            overlay[mask_np == idx] = np.array(colormap(idx)[:3]) * 255

    blended = Image.fromarray(np.uint8(img_np * (1 - alpha) + overlay * alpha))

    return blended


def create_legend_image(labels_to_ids: dict[str, int], filename="legend.png"):
    num_cls = len(labels_to_ids)
    palette = get_palette(num_cls)
    colormap = create_colormap(palette)

    fig, ax = plt.subplots(figsize=(4, 6), facecolor="white")

    ax.axis("off")

    legend_elements = [
        Patch(facecolor=colormap(i), edgecolor="black", label=label)
        for label, i in sorted(labels_to_ids.items(), key=lambda x: x[1])
    ]

    plt.title("Legend", fontsize=16, fontweight="bold", pad=20)

    legend = ax.legend(
        handles=legend_elements,
        loc="center",
        bbox_to_anchor=(0.5, 0.5),
        ncol=2,
        frameon=True,
        fancybox=True,
        shadow=True,
        fontsize=10,
        title_fontsize=12,
        borderpad=1,
        labelspacing=1.2,
        handletextpad=0.5,
        handlelength=1.5,
        columnspacing=1.5,
    )

    legend.get_frame().set_facecolor("#FAFAFA")
    legend.get_frame().set_edgecolor("gray")

    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(filename, dpi=300, bbox_inches="tight")
    plt.close()


# The legend is generated once and shipped as a static asset:
# create_legend_image(LABELS_TO_IDS, filename=os.path.join(ASSETS_DIR, "legend.png"))


# ----------------- MODEL ----------------- #

URL = "https://huggingface.co/facebook/sapiens/resolve/main/sapiens_lite_host/torchscript/seg/checkpoints/sapiens_0.3b/sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2?download=true"
CHECKPOINTS_DIR = os.path.join(ASSETS_DIR, "checkpoints")
model_path = os.path.join(CHECKPOINTS_DIR, "sapiens_0.3b_goliath_best_goliath_mIoU_7673_epoch_194_torchscript.pt2")

# Download the TorchScript checkpoint on first launch.
if not os.path.exists(model_path):
    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
    import requests

    response = requests.get(URL)
    with open(model_path, "wb") as file:
        file.write(response.content)

model = torch.jit.load(model_path)
model.eval()


@torch.no_grad()
def run_model(input_tensor, height, width):
    # Upsample the logits back to the original image size, then take the
    # argmax class per pixel.
    output = model(input_tensor)
    output = torch.nn.functional.interpolate(output, size=(height, width), mode="bilinear", align_corners=False)
    _, preds = torch.max(output, 1)
    return preds


# Sapiens expects a 768x1024 (width x height) input, normalized with ImageNet statistics.
transform_fn = transforms.Compose(
    [
        transforms.Resize((1024, 768)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# ----------------- CORE FUNCTION ----------------- #


def segment(image: Image.Image) -> Image.Image:
    image = image.convert("RGB")  # ensure 3 channels; uploaded PNGs may be RGBA
    input_tensor = transform_fn(image).unsqueeze(0)
    preds = run_model(input_tensor, height=image.height, width=image.width)
    mask = preds.squeeze(0).cpu().numpy()
    mask_image = Image.fromarray(mask.astype("uint8"))
    blended_image = visualize_mask_with_overlay(image, mask_image, LABELS_TO_IDS, alpha=0.5)
    return blended_image


# ----------------- GRADIO UI ----------------- #


with open("banner.html", "r") as file:
    banner = file.read()
with open("tips.html", "r") as file:
    tips = file.read()

CUSTOM_CSS = """
.image-container img {
    max-width: 512px;
    max-height: 512px;
    margin: 0 auto;
    border-radius: 0px;
}
.gradio-container {background-color: #fafafa}
"""

with gr.Blocks(css=CUSTOM_CSS, theme=gr.themes.Monochrome(radius_size=sizes.radius_md)) as demo:
    gr.HTML(banner)
    gr.HTML(tips)
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil", format="png")

            example_model = gr.Examples(
                inputs=input_image,
                examples_per_page=10,
                examples=[
                    os.path.join(ASSETS_DIR, "examples", img)
                    for img in os.listdir(os.path.join(ASSETS_DIR, "examples"))
                ],
            )
        with gr.Column():
            result_image = gr.Image(label="Segmentation Result", format="png")
            run_button = gr.Button("Run")

    gr.Image(os.path.join(ASSETS_DIR, "legend.png"), label="Legend", type="filepath")

    run_button.click(
        fn=segment,
        inputs=[input_image],
        outputs=[result_image],
    )


if __name__ == "__main__":
    demo.launch(share=False)
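For local smoke-testing, a minimal sketch that drives the `segment` pipeline directly, without the Gradio UI. It assumes the snippet runs from the repo root (so `banner.html`, `tips.html`, and `assets/` resolve when `app.py` is imported) and uses an illustrative test-image filename:

# Minimal local check of the segmentation pipeline (no UI).
from PIL import Image

from app import segment  # importing app.py downloads the checkpoint on first run

image = Image.open("person.png").convert("RGB")  # illustrative filename
overlay = segment(image)  # PIL image: input blended with the color-coded mask
overlay.save("person_segmented.png")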
banner.html
ADDED
@@ -0,0 +1,68 @@
<div style="
    display: flex;
    flex-direction: column;
    justify-content: center;
    align-items: center;
    text-align: center;
    background: linear-gradient(45deg, #1a1a1a 0%, #333333 100%);
    padding: 24px;
    gap: 24px;
    border-radius: 8px;
">
  <div style="display: flex; gap: 8px;">
    <h1 style="
        font-size: 48px;
        color: #fafafa;
        margin: 0;
        font-family: 'Trebuchet MS', 'Lucida Sans Unicode', 'Lucida Grande',
          'Lucida Sans', Arial, sans-serif;
    ">
      Sapiens 0.3B: Body-Part Segmentation
    </h1>
  </div>

  <p style="
      margin: 0;
      line-height: 1.6rem;
      font-size: 16px;
      color: #fafafa;
      opacity: 0.8;
  ">
    <a href="https://about.meta.com/realitylabs/codecavatars/sapiens/" target="_blank">Sapiens</a> is a family of
    human-centric foundation models trained by Meta Reality Labs. <br />
    This Space is brought to you by FASHN AI to showcase the capabilities of Sapiens for body-part segmentation.
  </p>

  <div style="
      display: flex;
      justify-content: center;
      align-items: center;
      text-align: center;
  ">
    <a href="https://fashn.ai"><img
        src="https://custom-icon-badges.demolab.com/badge/FASHN_AI-333333?style=for-the-badge&logo=fashn"
        alt="FASHN AI" /></a>
    <a href="https://github.com/fashn-AI"><img
        src="https://img.shields.io/badge/github-%23121011.svg?style=for-the-badge&logo=github&logoColor=white"
        alt="GitHub" /></a>
    <a href="https://www.linkedin.com/company/fashn">
      <img src="https://img.shields.io/badge/linkedin-%230077B5.svg?style=for-the-badge&logo=linkedin&logoColor=white"
        alt="LinkedIn" />
    </a>
    <a href="https://x.com/fashn_ai"><img
        src="https://img.shields.io/badge/@fashn_ai-%23000000.svg?style=for-the-badge&logo=X&logoColor=white"
        alt="X" /></a>
    <a href="https://www.instagram.com/fashn.ai/"><img
        src="https://img.shields.io/badge/Fashn.ai-%23E4405F.svg?style=for-the-badge&logo=Instagram&logoColor=white"
        alt="Instagram" /></a>
    <a href="https://discord.gg/zfqzkGBxE5">
      <img src="https://img.shields.io/badge/fashn_ai-%235865F2.svg?style=for-the-badge&logo=discord&logoColor=white"
        alt="Discord" />
    </a>
  </div>
</div>
.gitignore
ADDED
@@ -0,0 +1,2 @@
.DS_Store
*.pt2
tips.html
ADDED
@@ -0,0 +1,24 @@
<div style="
    padding: 12px;
    border: 1px solid #333333;
    border-radius: 8px;
    text-align: center;
    display: flex;
    flex-direction: column;
    gap: 8px;
">
  <b style="font-size: 18px;">❣️ Tips for successful segmentations</b>

  <ul style="
      display: flex;
      gap: 12px;
      justify-content: center;
  ">
    <li style="margin: 0;">3:4 aspect ratio</li>
    <li style="margin: 0;">768x1024 (width x height) resolution</li>
  </ul>
</div>
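Because app.py resizes every input to 768x1024 with transforms.Resize((1024, 768)), images that are not 3:4 get stretched; hence the tips above. A hedged sketch of letterboxing an arbitrary image to 3:4 before upload (pad_to_3_4 and the filenames are illustrative helpers, not part of this Space):

from PIL import Image, ImageOps


def pad_to_3_4(image: Image.Image, fill=(0, 0, 0)) -> Image.Image:
    # Pad (letterbox) the image onto the smallest 3:4 canvas that contains it.
    w, h = image.size
    if w * 4 > h * 3:  # too wide -> grow the canvas vertically
        target = (w, (w * 4 + 2) // 3)
    else:  # too tall, or already 3:4 -> grow the canvas horizontally
        target = ((h * 3 + 3) // 4, h)
    return ImageOps.pad(image.convert("RGB"), target, color=fill)


img = pad_to_3_4(Image.open("photo.jpg"))  # illustrative input filename
img = img.resize((768, 1024))  # the recommended width x height
img.save("photo_3x4.png")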