import gradio as gr
import numpy as np
import PIL.Image
import torch
from matplotlib import colormaps
from matplotlib.colors import to_hex

from src.models.dino import DINOSegmentationModel
from src.models.vit import ViTSegmentation
from src.models.unet import UNet
from src.utils import get_transform


# Inference-only demo, so everything runs on CPU.
device = torch.device("cpu")
model_weight1 = "weights/dino.pth"
model_weight2 = "weights/vit.pth"
model_weight3 = "weights/unet.pth"

# For the DINO and ViT models only the segmentation heads carry saved weights;
# the UNet checkpoint holds the full model.
model1 = DINOSegmentationModel()
model1.segmentation_head.load_state_dict(torch.load(model_weight1, map_location=device))
model1.eval()
model2 = ViTSegmentation()
model2.segmentation_head.load_state_dict(torch.load(model_weight2, map_location=device))
model2.eval()
model3 = UNet()
model3.load_state_dict(torch.load(model_weight3, map_location=device))
model3.eval()
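
# Map the dropdown choice to a model instance; segment_image() dispatches on it.
MODELS = {"DINO": model1, "ViT": model2, "UNet": model3}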

# Class index (stored as a string key) → human-readable label.
mask_labels = {
    "0": "Background", "1": "Hat", "2": "Hair", "3": "Sunglasses", "4": "Upper-clothes",
    "5": "Skirt", "6": "Pants", "7": "Dress", "8": "Belt", "9": "Right-shoe",
    "10": "Left-shoe", "11": "Face", "12": "Right-leg", "13": "Left-leg",
    "14": "Right-arm", "15": "Left-arm", "16": "Bag", "17": "Scarf"
}

# "tab20" resampled to one color per class (18 classes). Integer indexing is
# used in both lookups so the legend chips match the mask colors exactly.
color_map = colormaps["tab20"].resampled(18)
label_colors = {label: to_hex(color_map(idx)[:3]) for idx, label in enumerate(mask_labels)}
fixed_colors = (np.array([color_map(i)[:3] for i in range(18)]) * 255).astype(np.uint8)


def mask_to_color(mask: np.ndarray) -> np.ndarray:
    """Convert an (H, W) class-index mask into an (H, W, 3) uint8 color image."""
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_idx in range(18):
        color_mask[mask == class_idx] = fixed_colors[class_idx]
    return color_mask
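
# Note: for masks whose values stay within [0, 17], `fixed_colors[mask]` is a
# vectorized equivalent of the loop above (NumPy fancy indexing).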


def segment_image(image: PIL.Image.Image, model_name: str) -> PIL.Image.Image:
    if image is None:
        return None
    model = MODELS.get(model_name, model3)

    # Preprocess with the model's own normalization stats and add a batch dim.
    original_width, original_height = image.size
    transform = get_transform(model.mean, model.std)
    input_tensor = transform(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(input_tensor)
    # (1, C, H, W) logits -> (H, W) per-pixel class indices.
    mask = torch.argmax(logits.squeeze(), dim=0).cpu().numpy()

    mask_image = PIL.Image.fromarray(mask_to_color(mask))

    # Scale the mask back to the original height (nearest-neighbor, so class
    # colors stay crisp) and center it horizontally on an original-size canvas.
    mask_aspect_ratio = mask_image.width / mask_image.height
    new_height = original_height
    new_width = int(new_height * mask_aspect_ratio)
    mask_image = mask_image.resize((new_width, new_height), PIL.Image.Resampling.NEAREST)

    final_mask = PIL.Image.new("RGB", (original_width, original_height))
    offset = ((original_width - new_width) // 2, 0)
    final_mask.paste(mask_image, offset)

    return final_mask
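
# Quick local check (outside the Gradio UI), using one of the bundled examples:
#     img = PIL.Image.open("assets/images_examples/image1.jpg")
#     segment_image(img, "UNet").save("mask.png")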


def generate_legend_html_compact() -> str:
    """Build a compact HTML legend: one colored chip per class label."""
    legend_html = "<div style='display: flex; flex-wrap: wrap; gap: 10px; justify-content: center;'>"
    for label, color in label_colors.items():
        legend_html += f"""
        <div style='display: flex; align-items: center; justify-content: center;
                    padding: 5px 10px; border: 1px solid {color};
                    background-color: {color}; border-radius: 5px;
                    color: white; font-size: 12px; text-align: center;'>
            {mask_labels[label]}
        </div>
        """
    legend_html += "</div>"
    return legend_html


examples = [
    ["assets/images_examples/image1.jpg"],
    ["assets/images_examples/image2.jpg"],
    ["assets/images_examples/image3.jpg"]
]


# Two-column layout: input image and model picker on the left, predicted mask
# and color legend on the right.
with gr.Blocks() as demo:
    gr.Markdown("## Clothes Segmentation")
    with gr.Row():
        with gr.Column():
            pic = gr.Image(label="Upload Human Image", type="pil", height=300, width=300)
            model_choice = gr.Dropdown(choices=["DINO", "ViT", "UNet"], label="Select Model", value="DINO")
            with gr.Row():
                with gr.Column(scale=1):
                    predict_btn = gr.Button("Predict")
                with gr.Column(scale=1):
                    clear_btn = gr.Button("Clear")
        
        with gr.Column():
            output = gr.Image(label="Mask", type="pil", height=300, width=300)
            legend = gr.HTML(label="Legend", value=generate_legend_html_compact())

    predict_btn.click(fn=segment_image, inputs=[pic, model_choice], outputs=output, api_name="predict")
    clear_btn.click(lambda: (None, None), outputs=[pic, output])
    gr.Examples(examples=examples, inputs=[pic])

demo.launch()
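
# `demo.launch(share=True)` would additionally expose a temporary public URL,
# which can be useful when running the demo on a local machine.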