switch to ONNX backend
- app.py +95 -146
- requirements.txt +2 -1
- segment_anything/onnx/__init__.py +1 -0
- segment_anything/onnx/predictor_onnx.py +106 -0
app.py
CHANGED
@@ -1,14 +1,18 @@
 # Code credit: [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM).
 
+import torch
 import gradio as gr
 import numpy as np
-import torch
-from segment_anything import […]
+from segment_anything import sam_model_registry, SamPredictor
+from segment_anything.onnx import SamPredictorONNX
 from PIL import ImageDraw
 from utils.tools_gradio import fast_process
 import copy
 import argparse
 
+# Use ONNX to speed up the inference.
+ENABLE_ONNX = True
+
 parser = argparse.ArgumentParser(
     description="Host EdgeSAM as a local web service."
 )
@@ -16,13 +20,19 @@ parser.add_argument(
     "--checkpoint",
     default="weights/edge_sam_3x.pth",
     type=str,
-    help="The path to the […]
+    help="The path to the PyTorch checkpoint of EdgeSAM."
+)
+parser.add_argument(
+    "--encoder-onnx-path",
+    default="weights/edge_sam_3x_encoder.onnx",
+    type=str,
+    help="The path to the ONNX model of EdgeSAM's encoder."
 )
 parser.add_argument(
-    "--enable-everything-mode",
-    […]
+    "--decoder-onnx-path",
+    default="weights/edge_sam_3x_decoder.onnx",
+    type=str,
+    help="The path to the ONNX model of EdgeSAM's decoder."
 )
 parser.add_argument(
     "--server-name",
@@ -39,12 +49,32 @@ parser.add_argument(
 args = parser.parse_args()
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-[…]
+if ENABLE_ONNX:
+    predictor = SamPredictorONNX(args.encoder_onnx_path, args.decoder_onnx_path)
+else:
+    sam = sam_model_registry["edge_sam"](checkpoint=args.checkpoint, upsample_mode="bicubic")
+    sam = sam.to(device=device)
+    sam.eval()
+    predictor = SamPredictor(sam)
 
-[…]
+examples = [
+    ["assets/1.jpeg"],
+    ["assets/2.jpeg"],
+    ["assets/3.jpeg"],
+    ["assets/4.jpeg"],
+    ["assets/5.jpeg"],
+    ["assets/6.jpeg"],
+    ["assets/7.jpeg"],
+    ["assets/8.jpeg"],
+    ["assets/9.jpeg"],
+    ["assets/10.jpeg"],
+    ["assets/11.jpeg"],
+    ["assets/12.jpeg"],
+    ["assets/13.jpeg"],
+    ["assets/14.jpeg"],
+    ["assets/15.jpeg"],
+    ["assets/16.jpeg"]
+]
 
 # Description
 title = "<center><strong><font size='8'>EdgeSAM<font></strong> <a href='https://github.com/chongzhou96/EdgeSAM'><font size='6'>[GitHub]</font></a> </center>"
@@ -68,35 +98,6 @@ description_b = """ # Instructions for box mode
 
 """
 
-description_e = """ # Everything mode is NOT recommended.
-
-Since EdgeSAM follows the same encoder-decoder architecture as SAM, the everything mode will infer the decoder 32x32=1024 times, which is inefficient, thus a longer processing time is expected.
-
-1. Upload an image or click one of the provided examples.
-2. Click Start to get the segmentation mask.
-3. The Reset button resets the image and masks.
-
-"""
-
-examples = [
-    ["assets/1.jpeg"],
-    ["assets/2.jpeg"],
-    ["assets/3.jpeg"],
-    ["assets/4.jpeg"],
-    ["assets/5.jpeg"],
-    ["assets/6.jpeg"],
-    ["assets/7.jpeg"],
-    ["assets/8.jpeg"],
-    ["assets/9.jpeg"],
-    ["assets/10.jpeg"],
-    ["assets/11.jpeg"],
-    ["assets/12.jpeg"],
-    ["assets/13.jpeg"],
-    ["assets/14.jpeg"],
-    ["assets/15.jpeg"],
-    ["assets/16.jpeg"]
-]
-
 css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
 
 global_points = []
@@ -119,6 +120,7 @@ def reset():
     global_image_with_prompt = None
     return None
 
+
 def reset_all():
     global global_points
     global global_point_label
@@ -130,10 +132,7 @@ def reset_all():
    global_box = []
    global_image = None
    global_image_with_prompt = None
-    if args.enable_everything_mode:
-        return None, None, None
-    else:
-        return None, None
+    return None, None
 
 
 def clear():
@@ -185,14 +184,15 @@ def convert_box(xyxy):
     xyxy[1][1] = max_y
     return xyxy
 
+
 def segment_with_points(
-    […]
+    label,
+    evt: gr.SelectData,
+    input_size=1024,
+    better_quality=False,
+    withContours=True,
+    use_retina=True,
+    mask_random_color=False,
 ):
     global global_points
     global global_point_label
@@ -213,26 +213,30 @@ def segment_with_points(
     )
     image = global_image_with_prompt
 
-    […]
+    if ENABLE_ONNX:
+        global_points_np = np.array(global_points)[None]
+        global_point_label_np = np.array(global_point_label)[None]
+        masks, scores, _ = predictor.predict(
+            point_coords=global_points_np,
+            point_labels=global_point_label_np,
+        )
+        masks = masks.squeeze(0)
+        scores = scores.squeeze(0)
+    else:
+        global_points_np = np.array(global_points)
+        global_point_label_np = np.array(global_point_label)
+        masks, scores, logits = predictor.predict(
+            point_coords=global_points_np,
+            point_labels=global_point_label_np,
+            num_multimask_outputs=4,
+            use_stability_score=True
+        )
 
     print(f'scores: {scores}')
     area = masks.sum(axis=(1, 2))
     print(f'area: {area}')
 
-    […]
-        annotations = masks
-    else:
-        annotations = np.expand_dims(masks[scores.argmax()], axis=0)
+    annotations = np.expand_dims(masks[scores.argmax()], axis=0)
 
     seg = fast_process(
         annotations=annotations,
@@ -250,12 +254,12 @@ def segment_with_points(
 
 
 def segment_with_box(
-    […]
+    evt: gr.SelectData,
+    input_size=1024,
+    better_quality=False,
+    withContours=True,
+    use_retina=True,
+    mask_random_color=False,
 ):
     global global_box
     global global_image
@@ -292,12 +296,20 @@ def segment_with_box(
     )
 
     global_box_np = np.array(global_box)
-    […]
+    if ENABLE_ONNX:
+        point_coords = global_box_np.reshape(2, 2)[None]
+        point_labels = np.array([2, 3])[None]
+        masks, _, _ = predictor.predict(
+            point_coords=point_coords,
+            point_labels=point_labels,
+        )
+        annotations = masks[:, 0, :, :]
+    else:
+        masks, scores, _ = predictor.predict(
+            box=global_box_np,
+            num_multimask_outputs=1,
+        )
+        annotations = masks
 
     seg = fast_process(
         annotations=annotations,
@@ -313,44 +325,10 @@ def segment_with_box(
         return seg
     return image
 
-
-def segment_everything(
-    image,
-    input_size=1024,
-    better_quality=False,
-    withContours=True,
-    use_retina=True,
-    mask_random_color=True,
-):
-    nd_image = np.array(image)
-    masks = mask_generator.generate(nd_image)
-    annotations = masks
-    seg = fast_process(
-        annotations=annotations,
-        image=image,
-        device=device,
-        scale=(1024 // input_size),
-        better_quality=better_quality,
-        mask_random_color=mask_random_color,
-        bbox=None,
-        use_retina=use_retina,
-        withContours=withContours,
-    )
-
-    return seg
-
-
 img_p = gr.Image(label="Input with points", type="pil")
 img_b = gr.Image(label="Input with box", type="pil")
-img_e = gr.Image(label="Input (everything)", type="pil")
-
-if args.enable_everything_mode:
-    all_outputs = [img_p, img_b, img_e]
-else:
-    all_outputs = [img_p, img_b]
 
 with gr.Blocks(css=css, title="EdgeSAM") as demo:
-
     with gr.Row():
         with gr.Column(scale=1):
             # Title
@@ -410,53 +388,24 @@ with gr.Blocks(css=css, title="EdgeSAM") as demo:
                 run_on_click=True
            )
 
-    if args.enable_everything_mode:
-        with gr.Tab("Everything mode") as tab_e:
-            # Images
-            with gr.Row(variant="panel"):
-                with gr.Column(scale=1):
-                    img_e.render()
-                with gr.Column(scale=1):
-                    with gr.Row():
-                        with gr.Column():
-                            segment_btn_e = gr.Button("Start", variant="primary")
-                            reset_btn_e = gr.Button("Reset", variant="secondary")
-                    gr.Markdown(description_e)
-
-            # Submit & Clear
-            with gr.Row():
-                with gr.Column():
-                    gr.Markdown("Try some of the examples below ⬇️")
-                    gr.Examples(
-                        examples=examples,
-                        inputs=[img_e],
-                        examples_per_page=8,
-                    )
-
     with gr.Row():
         with gr.Column(scale=1):
-            gr.Markdown([…]
+            gr.Markdown(
+                "<center><img src='https://visitor-badge.laobi.icu/badge?page_id=chongzhou/edgesam' alt='visitors'></center>")
 
     img_p.upload(on_image_upload, img_p, [img_p])
     img_p.select(segment_with_points, [add_or_remove], img_p)
 
     clear_btn_p.click(clear, outputs=[img_p])
     reset_btn_p.click(reset, outputs=[img_p])
-    tab_p.select(fn=reset_all, outputs=all_outputs)
+    tab_p.select(fn=reset_all, outputs=[img_p, img_b])
 
     img_b.upload(on_image_upload, img_b, [img_b])
     img_b.select(segment_with_box, outputs=[img_b])
 
     clear_btn_b.click(clear, outputs=[img_b])
     reset_btn_b.click(reset, outputs=[img_b])
-    tab_b.select(fn=reset_all, outputs=all_outputs)
-
-    if args.enable_everything_mode:
-        segment_btn_e.click(
-            segment_everything, inputs=[img_e], outputs=img_e
-        )
-        reset_btn_e.click(reset, outputs=[img_e])
-        tab_e.select(fn=reset_all, outputs=all_outputs)
+    tab_b.select(fn=reset_all, outputs=[img_p, img_b])
 
 demo.queue()
 # demo.launch(server_name=args.server_name, server_port=args.port)
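Both prompt paths above follow SAM's ONNX prompt encoding: prompts carry a leading batch dimension, point labels 1/0 mark foreground/background clicks, and a box is passed as two extra points labelled 2 (top-left corner) and 3 (bottom-right corner), which is why segment_with_box reshapes the box into (1, 2, 2) coordinates with labels [2, 3]. A small sketch, not part of the commit and using made-up pixel values, of the arrays the two handlers build before calling predictor.predict:

# Illustration only (not in the commit): prompt shapes expected by SamPredictorONNX.
import numpy as np

# Point prompts: one batch of N clicks; label 1 = foreground, 0 = background.
point_coords = np.array([[[250, 187], [400, 310]]], dtype=np.float32)  # (1, 2, 2), made-up coords
point_labels = np.array([[1, 0]], dtype=np.float32)                    # (1, 2)

# Box prompt: the two corners are encoded as "points" labelled 2 and 3.
box_xyxy = np.array([100, 150, 400, 500], dtype=np.float32)            # made-up x1, y1, x2, y2
box_coords = box_xyxy.reshape(2, 2)[None]                              # (1, 2, 2)
box_labels = np.array([[2, 3]], dtype=np.float32)                      # (1, 2)

# Either pair is then passed to the ONNX predictor:
# masks, scores, low_res_masks = predictor.predict(point_coords=..., point_labels=...)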
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
 torch
 torchvision
 opencv-python
-timm
+timm
+onnxruntime
segment_anything/onnx/__init__.py
ADDED
@@ -0,0 +1 @@
+from .predictor_onnx import SamPredictorONNX
segment_anything/onnx/predictor_onnx.py
ADDED
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import cv2
+
+import onnxruntime
+from typing import Optional, Tuple
+
+from ..utils.transforms import ResizeLongestSide
+
+
+class SamPredictorONNX:
+    mask_threshold: float = 0.0
+    image_format: str = "RGB"
+    img_size = 1024
+    pixel_mean = np.array([123.675, 116.28, 103.53])[None, :, None, None]
+    pixel_std = np.array([58.395, 57.12, 57.375])[None, :, None, None]
+
+    def __init__(
+        self,
+        encoder_path: str,
+        decoder_path: str
+    ) -> None:
+        super().__init__()
+        self.encoder = onnxruntime.InferenceSession(encoder_path)
+        self.decoder = onnxruntime.InferenceSession(decoder_path)
+
+        # Set the execution provider to GPU if available
+        if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
+            self.encoder.set_providers(['CUDAExecutionProvider'])
+            self.decoder.set_providers(['CUDAExecutionProvider'])
+
+        self.transform = ResizeLongestSide(self.img_size)
+        self.reset_image()
+
+    def set_image(
+        self,
+        image: np.ndarray,
+        image_format: str = "RGB",
+    ) -> None:
+        assert image_format in [
+            "RGB",
+            "BGR",
+        ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
+        if image_format != self.image_format:
+            image = image[..., ::-1]
+
+        # Transform the image to the form expected by the model
+        input_image = self.transform.apply_image(image)
+        input_image = input_image.transpose(2, 0, 1)[None, :, :, :]
+        self.reset_image()
+        self.original_size = image.shape[:2]
+        self.input_size = tuple(input_image.shape[-2:])
+        input_image = self.preprocess(input_image).astype(np.float32)
+        outputs = self.encoder.run(None, {'image': input_image})
+        self.features = outputs[0]
+        self.is_image_set = True
+
+    def predict(
+        self,
+        point_coords: Optional[np.ndarray] = None,
+        point_labels: Optional[np.ndarray] = None,
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        if not self.is_image_set:
+            raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+        point_coords = self.transform.apply_coords(point_coords, self.original_size)
+        outputs = self.decoder.run(None, {
+            'image_embeddings': self.features,
+            'point_coords': point_coords.astype(np.float32),
+            'point_labels': point_labels.astype(np.float32)
+        })
+        scores, low_res_masks = outputs[0], outputs[1]
+        masks = self.postprocess_masks(low_res_masks)
+        masks = masks > self.mask_threshold
+
+        return masks, scores, low_res_masks
+
+    def reset_image(self) -> None:
+        """Resets the currently set image."""
+        self.is_image_set = False
+        self.features = None
+        self.orig_h = None
+        self.orig_w = None
+        self.input_h = None
+        self.input_w = None
+
+    def preprocess(self, x: np.ndarray):
+        x = (x - self.pixel_mean) / self.pixel_std
+        h, w = x.shape[-2:]
+        padh = self.img_size - h
+        padw = self.img_size - w
+        x = np.pad(x, ((0, 0), (0, 0), (0, padh), (0, padw)), mode='constant', constant_values=0)
+        return x
+
+    def postprocess_masks(self, mask: np.ndarray):
+        mask = mask.squeeze(0).transpose(1, 2, 0)
+        mask = cv2.resize(mask, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)
+        mask = mask[:self.input_size[0], :self.input_size[1], :]
+        mask = cv2.resize(mask, (self.original_size[1], self.original_size[0]), interpolation=cv2.INTER_LINEAR)
+        mask = mask.transpose(2, 0, 1)[None, :, :, :]
+        return mask
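For reference, a minimal end-to-end sketch of the new ONNX backend outside the Gradio app, assuming the default weight paths and example image from this diff and a hypothetical click location; it mirrors what segment_with_points does when ENABLE_ONNX is True:

# Minimal usage sketch (not part of the commit).
import cv2
import numpy as np
from segment_anything.onnx import SamPredictorONNX

predictor = SamPredictorONNX(
    "weights/edge_sam_3x_encoder.onnx",
    "weights/edge_sam_3x_decoder.onnx",
)

image = cv2.imread("assets/1.jpeg")                 # OpenCV loads images as BGR
predictor.set_image(image, image_format="BGR")      # converted to RGB and embedded once

# One positive click (hypothetical coordinate); note the leading batch dimension.
point_coords = np.array([[[320, 240]]], dtype=np.float32)  # (1, 1, 2)
point_labels = np.array([[1]], dtype=np.float32)           # (1, 1), 1 = foreground

masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels)
best_mask = masks[0][scores[0].argmax()]            # boolean H x W mask with the highest score
print(best_mask.shape, scores)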