Spaces:
Build error
Build error
jhj0517
commited on
Commit
·
11d7b39
1
Parent(s):
4af6ba2
initial commit
Browse files- .gitignore +3 -0
- app.py +49 -0
- models/model file will be saved here.txt +0 -0
- modules/__init__.py +0 -0
- modules/html_constants.py +84 -0
- modules/mask_utils.py +99 -0
- modules/model_downloader.py +12 -0
- modules/sam.py +65 -0
- modules/ui_utils.py +8 -0
- outputs/psd/psd file will be saved here.txt +0 -0
- requirements.txt +12 -0
- screenshot.png +0 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
models/
|
2 |
+
modules/__pycache__/
|
3 |
+
outputs/
|
app.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
from modules import sam
|
4 |
+
from modules.ui_utils import *
|
5 |
+
from modules.html_constants import *
|
6 |
+
|
7 |
+
|
8 |
+
class App:
|
9 |
+
def __init__(self):
|
10 |
+
#download_sam_model_url()
|
11 |
+
self.app = gr.Blocks(css=CSS)
|
12 |
+
self.sam = sam.SamInference()
|
13 |
+
|
14 |
+
def launch(self):
|
15 |
+
with self.app:
|
16 |
+
with gr.Row():
|
17 |
+
gr.Markdown(MARKDOWN_NOTE, elem_id="md_pgroject")
|
18 |
+
with gr.Row().style(equal_height=True): # bug https://github.com/gradio-app/gradio/issues/3202
|
19 |
+
with gr.Column(scale=5):
|
20 |
+
img_input = gr.Image(label="Input image here")
|
21 |
+
with gr.Column(scale=5):
|
22 |
+
# Tuable Params
|
23 |
+
nb_points_per_side = gr.Number(label="points_per_side", value=32)
|
24 |
+
sld_pred_iou_thresh = gr.Slider(label="pred_iou_thresh", value=0.88, minimum=0, maximum=1)
|
25 |
+
sld_stability_score_thresh = gr.Slider(label="stability_score_thresh", value=0.95, minimum=0,
|
26 |
+
maximum=1)
|
27 |
+
nb_crop_n_layers = gr.Number(label="crop_n_layers", value=0)
|
28 |
+
nb_crop_n_points_downscale_factor = gr.Number(label="crop_n_points_downscale_factor", value=1)
|
29 |
+
nb_min_mask_region_area = gr.Number(label="min_mask_region_area", value=0)
|
30 |
+
html_param_explain = gr.HTML(PARAMS_EXPLANATION, elem_id="html_param_explain")
|
31 |
+
|
32 |
+
with gr.Row():
|
33 |
+
btn_generate = gr.Button("GENERATE", variant="primary")
|
34 |
+
with gr.Row():
|
35 |
+
gallery_output = gr.Gallery(label="Output will be shown here", show_label=True).style(grid=5,
|
36 |
+
height="auto")
|
37 |
+
btn_open_folder = gr.Button("📁\n(PSD)").style(full_width=False)
|
38 |
+
|
39 |
+
params = [nb_points_per_side, sld_pred_iou_thresh, sld_stability_score_thresh, nb_crop_n_layers,
|
40 |
+
nb_crop_n_points_downscale_factor, nb_min_mask_region_area]
|
41 |
+
btn_generate.click(fn=self.sam.generate_mask_app, inputs=[img_input] + params, outputs=gallery_output)
|
42 |
+
btn_open_folder.click(fn=lambda: open_folder("outputs\psd"), inputs=None, outputs=None)
|
43 |
+
|
44 |
+
self.app.queue(api_open=False).launch()
|
45 |
+
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
app = App()
|
49 |
+
app.launch()
|
models/model file will be saved here.txt
ADDED
File without changes
|
modules/__init__.py
ADDED
File without changes
|
modules/html_constants.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
CSS = """
|
2 |
+
#md_project a {
|
3 |
+
color: black;
|
4 |
+
text-decoration: none;
|
5 |
+
}
|
6 |
+
#md_project a:hover {
|
7 |
+
text-decoration: underline;
|
8 |
+
}
|
9 |
+
"""
|
10 |
+
|
11 |
+
|
12 |
+
PROJECT_NAME = """
|
13 |
+
# [Layer-Divider-WebUI](https://github.com/jhj0517/Layer-Divider-WebUI)
|
14 |
+
"""
|
15 |
+
|
16 |
+
MARKDOWN_NOTE = """
|
17 |
+
## This space only support CPU because it's free huggingface space.
|
18 |
+
## If you want to run CUDA version , check this [repository](https://github.com/jhj0517/Layer-Divider-WebUI)
|
19 |
+
"""
|
20 |
+
|
21 |
+
PARAMS_EXPLANATION = """
|
22 |
+
<!DOCTYPE html>
|
23 |
+
<html lang="en">
|
24 |
+
<head>
|
25 |
+
<meta charset="UTF-8">
|
26 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
27 |
+
<style>
|
28 |
+
table {
|
29 |
+
border-collapse: collapse;
|
30 |
+
width: 100%;
|
31 |
+
}
|
32 |
+
th, td {
|
33 |
+
border: 1px solid #dddddd;
|
34 |
+
text-align: left;
|
35 |
+
padding: 8px;
|
36 |
+
}
|
37 |
+
th {
|
38 |
+
background-color: #f2f2f2;
|
39 |
+
}
|
40 |
+
</style>
|
41 |
+
</head>
|
42 |
+
<body>
|
43 |
+
|
44 |
+
<details>
|
45 |
+
<summary>Explanation of Each Parameter</summary>
|
46 |
+
<table>
|
47 |
+
<thead>
|
48 |
+
<tr>
|
49 |
+
<th>Parameter</th>
|
50 |
+
<th>Description</th>
|
51 |
+
</tr>
|
52 |
+
</thead>
|
53 |
+
<tbody>
|
54 |
+
<tr>
|
55 |
+
<td>points_per_side</td>
|
56 |
+
<td>The number of points to be sampled along one side of the image. The total number of points is points_per_side**2. If None, 'point_grids' must provide explicit point sampling.</td>
|
57 |
+
</tr>
|
58 |
+
<tr>
|
59 |
+
<td>pred_iou_thresh</td>
|
60 |
+
<td>A filtering threshold in [0,1], using the model's predicted mask quality.</td>
|
61 |
+
</tr>
|
62 |
+
<tr>
|
63 |
+
<td>stability_score_thresh</td>
|
64 |
+
<td>A filtering threshold in [0,1], using the stability of the mask under changes to the cutoff used to binarize the model's mask predictions.</td>
|
65 |
+
</tr>
|
66 |
+
<tr>
|
67 |
+
<td>crops_n_layers</td>
|
68 |
+
<td>If >0, mask prediction will be run again on crops of the image. Sets the number of layers to run, where each layer has 2**i_layer number of image crops.</td>
|
69 |
+
</tr>
|
70 |
+
<tr>
|
71 |
+
<td>crop_n_points_downscale_factor</td>
|
72 |
+
<td>The number of points-per-side sampled in layer n is scaled down by crop_n_points_downscale_factor**n.</td>
|
73 |
+
</tr>
|
74 |
+
<tr>
|
75 |
+
<td>min_mask_region_area</td>
|
76 |
+
<td>If >0, postprocessing will be applied to remove disconnected regions and holes in masks with area smaller than min_mask_region_area. Requires opencv.</td>
|
77 |
+
</tr>
|
78 |
+
</tbody>
|
79 |
+
</table>
|
80 |
+
</details>
|
81 |
+
|
82 |
+
</body>
|
83 |
+
</html>
|
84 |
+
"""
|
modules/mask_utils.py
ADDED
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
from pycocotools import mask as coco_mask
|
4 |
+
from pytoshop import layers
|
5 |
+
import pytoshop
|
6 |
+
from pytoshop.enums import BlendMode
|
7 |
+
from datetime import datetime
|
8 |
+
|
9 |
+
|
10 |
+
def generate_random_color():
|
11 |
+
return np.random.randint(0, 256), np.random.randint(0, 256), np.random.randint(0, 256)
|
12 |
+
|
13 |
+
|
14 |
+
def create_base_layer(image):
|
15 |
+
rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
|
16 |
+
return [rgba_image]
|
17 |
+
|
18 |
+
|
19 |
+
def create_mask_layers(image, masks):
|
20 |
+
layer_list = []
|
21 |
+
|
22 |
+
for result in masks:
|
23 |
+
rle = result['segmentation']
|
24 |
+
mask = coco_mask.decode(rle).astype(np.uint8)
|
25 |
+
rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
|
26 |
+
rgba_image[..., 3] = cv2.bitwise_and(rgba_image[..., 3], rgba_image[..., 3], mask=mask)
|
27 |
+
|
28 |
+
layer_list.append(rgba_image)
|
29 |
+
|
30 |
+
return layer_list
|
31 |
+
|
32 |
+
|
33 |
+
def create_mask_gallery(image, masks):
|
34 |
+
mask_array_list = []
|
35 |
+
label_list = []
|
36 |
+
|
37 |
+
for index, result in enumerate(masks):
|
38 |
+
rle = result['segmentation']
|
39 |
+
mask = coco_mask.decode(rle).astype(np.uint8)
|
40 |
+
|
41 |
+
rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
|
42 |
+
rgba_image[..., 3] = cv2.bitwise_and(rgba_image[..., 3], rgba_image[..., 3], mask=mask)
|
43 |
+
|
44 |
+
mask_array_list.append(rgba_image)
|
45 |
+
label_list.append(f'Part {index}')
|
46 |
+
|
47 |
+
return [[img, label] for img, label in zip(mask_array_list, label_list)]
|
48 |
+
|
49 |
+
|
50 |
+
def create_mask_combined_images(image, masks):
|
51 |
+
final_result = np.zeros_like(image)
|
52 |
+
|
53 |
+
for result in masks:
|
54 |
+
rle = result['segmentation']
|
55 |
+
mask = coco_mask.decode(rle).astype(np.uint8)
|
56 |
+
|
57 |
+
color = generate_random_color()
|
58 |
+
colored_mask = np.zeros_like(image)
|
59 |
+
colored_mask[mask == 1] = color
|
60 |
+
|
61 |
+
final_result = cv2.addWeighted(final_result, 1, colored_mask, 0.5, 0)
|
62 |
+
|
63 |
+
combined_image = cv2.addWeighted(image, 1, final_result, 0.5, 0)
|
64 |
+
return [combined_image, "masked"]
|
65 |
+
|
66 |
+
|
67 |
+
def insert_psd_layer(psd, image_data, layer_name, blending_mode):
|
68 |
+
channel_data = [layers.ChannelImageData(image=image_data[:, :, i], compression=1) for i in range(4)]
|
69 |
+
|
70 |
+
layer_record = layers.LayerRecord(
|
71 |
+
channels={-1: channel_data[3], 0: channel_data[0], 1: channel_data[1], 2: channel_data[2]},
|
72 |
+
top=0, bottom=image_data.shape[0], left=0, right=image_data.shape[1],
|
73 |
+
blend_mode=blending_mode,
|
74 |
+
name=layer_name,
|
75 |
+
opacity=255,
|
76 |
+
)
|
77 |
+
psd.layer_and_mask_info.layer_info.layer_records.append(layer_record)
|
78 |
+
return psd
|
79 |
+
|
80 |
+
|
81 |
+
def save_psd(input_image_data, layer_data, layer_names, blending_modes):
|
82 |
+
psd_file = pytoshop.core.PsdFile(num_channels=3, height=input_image_data.shape[0], width=input_image_data.shape[1])
|
83 |
+
|
84 |
+
for index, layer in enumerate(layer_data):
|
85 |
+
psd_file = insert_psd_layer(psd_file, layer, layer_names[index], blending_modes[index])
|
86 |
+
|
87 |
+
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
88 |
+
with open(f"outputs/psd/result-{timestamp}.psd", 'wb') as output_file:
|
89 |
+
psd_file.write(output_file)
|
90 |
+
|
91 |
+
|
92 |
+
def save_psd_with_masks(image, masks):
|
93 |
+
original_layer = create_base_layer(image)
|
94 |
+
mask_layers = create_mask_layers(image, masks)
|
95 |
+
names = [f'Part {i}' for i in range(len(mask_layers))]
|
96 |
+
modes = [BlendMode.normal] * (len(mask_layers)+1)
|
97 |
+
save_psd(image, original_layer+mask_layers, ['Original_Image']+names, modes)
|
98 |
+
|
99 |
+
|
modules/model_downloader.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
AVAILABLE_MODELS = {
|
4 |
+
"ViT-H SAM model": ["sam_vit_h_4b8939.pth", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"],
|
5 |
+
"ViT-L SAM model": ["sam_vit_l_0b3195.pth", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth"],
|
6 |
+
"ViT-B SAM model": ["sam_vit_b_01ec64.pth", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"],
|
7 |
+
}
|
8 |
+
|
9 |
+
|
10 |
+
def download_sam_model_url():
|
11 |
+
torch.hub.download_url_to_file(AVAILABLE_MODELS["ViT-H SAM model"][1],
|
12 |
+
f'models/{AVAILABLE_MODELS["ViT-H SAM model"][0]}')
|
modules/sam.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
2 |
+
import os
|
3 |
+
|
4 |
+
from modules.mask_utils import *
|
5 |
+
from modules.model_downloader import *
|
6 |
+
|
7 |
+
|
8 |
+
class SamInference:
|
9 |
+
def __init__(self):
|
10 |
+
self.model = None
|
11 |
+
self.model_path = f"models/sam_vit_h_4b8939.pth"
|
12 |
+
self.device = "cuda"
|
13 |
+
self.mask_generator = None
|
14 |
+
|
15 |
+
# Tuable Parameters , All default values
|
16 |
+
self.tunable_params = {
|
17 |
+
'points_per_side': 32,
|
18 |
+
'pred_iou_thresh': 0.88,
|
19 |
+
'stability_score_thresh': 0.95,
|
20 |
+
'crop_n_layers': 0,
|
21 |
+
'crop_n_points_downscale_factor': 1,
|
22 |
+
'min_mask_region_area': 0
|
23 |
+
}
|
24 |
+
|
25 |
+
def set_mask_generator(self):
|
26 |
+
print("applying configs to model..")
|
27 |
+
if not os.path.exists(self.model_path):
|
28 |
+
print("No needed SAM model detected. downloading VIT H SAM model....")
|
29 |
+
download_sam_model_url()
|
30 |
+
|
31 |
+
self.model = sam_model_registry["default"](checkpoint=self.model_path)
|
32 |
+
self.model.to(device=self.device)
|
33 |
+
self.mask_generator = SamAutomaticMaskGenerator(
|
34 |
+
self.model,
|
35 |
+
points_per_side=self.tunable_params['points_per_side'],
|
36 |
+
pred_iou_thresh=self.tunable_params['pred_iou_thresh'],
|
37 |
+
stability_score_thresh=self.tunable_params['stability_score_thresh'],
|
38 |
+
crop_n_layers=self.tunable_params['crop_n_layers'],
|
39 |
+
crop_n_points_downscale_factor=self.tunable_params['crop_n_points_downscale_factor'],
|
40 |
+
min_mask_region_area=self.tunable_params['min_mask_region_area'],
|
41 |
+
output_mode="coco_rle",
|
42 |
+
)
|
43 |
+
|
44 |
+
def generate_mask(self, image):
|
45 |
+
return [self.mask_generator.generate(image)]
|
46 |
+
|
47 |
+
def generate_mask_app(self, image, *params):
|
48 |
+
tunable_params = {
|
49 |
+
'points_per_side': int(params[0]),
|
50 |
+
'pred_iou_thresh': float(params[1]),
|
51 |
+
'stability_score_thresh': float(params[2]),
|
52 |
+
'crop_n_layers': int(params[3]),
|
53 |
+
'crop_n_points_downscale_factor': int(params[4]),
|
54 |
+
'min_mask_region_area': int(params[5]),
|
55 |
+
}
|
56 |
+
|
57 |
+
if self.model is None or self.mask_generator is None or self.tunable_params != tunable_params:
|
58 |
+
self.tunable_params = tunable_params
|
59 |
+
self.set_mask_generator()
|
60 |
+
|
61 |
+
masks = self.mask_generator.generate(image)
|
62 |
+
save_psd_with_masks(image, masks)
|
63 |
+
combined_image = create_mask_combined_images(image, masks)
|
64 |
+
gallery = create_mask_gallery(image, masks)
|
65 |
+
return [combined_image] + gallery
|
modules/ui_utils.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
|
4 |
+
def open_folder(folder_path):
|
5 |
+
if os.path.exists(folder_path):
|
6 |
+
os.system(f"start {folder_path}")
|
7 |
+
else:
|
8 |
+
print(f"The folder {folder_path} does not exist.")
|
outputs/psd/psd file will be saved here.txt
ADDED
File without changes
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu117
|
2 |
+
torch
|
3 |
+
--extra-index-url https://download.pytorch.org/whl/cu117
|
4 |
+
torchvision
|
5 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
6 |
+
opencv-python
|
7 |
+
pycocotools
|
8 |
+
matplotlib
|
9 |
+
onnxruntime
|
10 |
+
onnx
|
11 |
+
gradio
|
12 |
+
pytoshop==1.2.0
|
screenshot.png
ADDED
![]() |