chongzhou commited on
Commit
69420c9
1 Parent(s): 9dd21b3

switch to ONNX backend

Browse files
app.py CHANGED
@@ -1,14 +1,18 @@
1
  # Code credit: [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM).
2
 
 
3
  import gradio as gr
4
  import numpy as np
5
- import torch
6
- from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
7
  from PIL import ImageDraw
8
  from utils.tools_gradio import fast_process
9
  import copy
10
  import argparse
11
 
 
 
 
12
  parser = argparse.ArgumentParser(
13
  description="Host EdgeSAM as a local web service."
14
  )
@@ -16,13 +20,19 @@ parser.add_argument(
16
  "--checkpoint",
17
  default="weights/edge_sam_3x.pth",
18
  type=str,
19
- help="The path to the EdgeSAM model checkpoint."
 
 
 
 
 
 
20
  )
21
  parser.add_argument(
22
- "--enable-everything-mode",
23
- action="store_true",
24
- help="Since EdgeSAM follows the same encoder-decoder architecture as SAM, the everything mode will infer the "
25
- "decoder 32x32=1024 times, which is inefficient, thus a longer processing time is expected.",
26
  )
27
  parser.add_argument(
28
  "--server-name",
@@ -39,12 +49,32 @@ parser.add_argument(
39
  args = parser.parse_args()
40
 
41
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
- sam = sam_model_registry["edge_sam"](checkpoint=args.checkpoint, upsample_mode="bicubic")
43
- sam = sam.to(device=device)
44
- sam.eval()
 
 
 
 
45
 
46
- mask_generator = SamAutomaticMaskGenerator(sam)
47
- predictor = SamPredictor(sam)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  # Description
50
  title = "<center><strong><font size='8'>EdgeSAM<font></strong> <a href='https://github.com/chongzhou96/EdgeSAM'><font size='6'>[GitHub]</font></a> </center>"
@@ -68,35 +98,6 @@ description_b = """ # Instructions for box mode
68
 
69
  """
70
 
71
- description_e = """ # Everything mode is NOT recommended.
72
-
73
- Since EdgeSAM follows the same encoder-decoder architecture as SAM, the everything mode will infer the decoder 32x32=1024 times, which is inefficient, thus a longer processing time is expected.
74
-
75
- 1. Upload an image or click one of the provided examples.
76
- 2. Click Start to get the segmentation mask.
77
- 3. The Reset button resets the image and masks.
78
-
79
- """
80
-
81
- examples = [
82
- ["assets/1.jpeg"],
83
- ["assets/2.jpeg"],
84
- ["assets/3.jpeg"],
85
- ["assets/4.jpeg"],
86
- ["assets/5.jpeg"],
87
- ["assets/6.jpeg"],
88
- ["assets/7.jpeg"],
89
- ["assets/8.jpeg"],
90
- ["assets/9.jpeg"],
91
- ["assets/10.jpeg"],
92
- ["assets/11.jpeg"],
93
- ["assets/12.jpeg"],
94
- ["assets/13.jpeg"],
95
- ["assets/14.jpeg"],
96
- ["assets/15.jpeg"],
97
- ["assets/16.jpeg"]
98
- ]
99
-
100
  css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
101
 
102
  global_points = []
@@ -119,6 +120,7 @@ def reset():
119
  global_image_with_prompt = None
120
  return None
121
 
 
122
  def reset_all():
123
  global global_points
124
  global global_point_label
@@ -130,10 +132,7 @@ def reset_all():
130
  global_box = []
131
  global_image = None
132
  global_image_with_prompt = None
133
- if args.enable_everything_mode:
134
- return None, None, None
135
- else:
136
- return None, None
137
 
138
 
139
  def clear():
@@ -185,14 +184,15 @@ def convert_box(xyxy):
185
  xyxy[1][1] = max_y
186
  return xyxy
187
 
 
188
  def segment_with_points(
189
- label,
190
- evt: gr.SelectData,
191
- input_size=1024,
192
- better_quality=False,
193
- withContours=True,
194
- use_retina=True,
195
- mask_random_color=False,
196
  ):
197
  global global_points
198
  global global_point_label
@@ -213,26 +213,30 @@ def segment_with_points(
213
  )
214
  image = global_image_with_prompt
215
 
216
- global_points_np = np.array(global_points)
217
- global_point_label_np = np.array(global_point_label)
218
-
219
- num_multimask_outputs = 4
220
-
221
- masks, scores, logits = predictor.predict(
222
- point_coords=global_points_np,
223
- point_labels=global_point_label_np,
224
- num_multimask_outputs=num_multimask_outputs,
225
- use_stability_score=True
226
- )
 
 
 
 
 
 
 
227
 
228
  print(f'scores: {scores}')
229
  area = masks.sum(axis=(1, 2))
230
  print(f'area: {area}')
231
 
232
- if num_multimask_outputs == 1:
233
- annotations = masks
234
- else:
235
- annotations = np.expand_dims(masks[scores.argmax()], axis=0)
236
 
237
  seg = fast_process(
238
  annotations=annotations,
@@ -250,12 +254,12 @@ def segment_with_points(
250
 
251
 
252
  def segment_with_box(
253
- evt: gr.SelectData,
254
- input_size=1024,
255
- better_quality=False,
256
- withContours=True,
257
- use_retina=True,
258
- mask_random_color=False,
259
  ):
260
  global global_box
261
  global global_image
@@ -292,12 +296,20 @@ def segment_with_box(
292
  )
293
 
294
  global_box_np = np.array(global_box)
295
-
296
- masks, scores, logits = predictor.predict(
297
- box=global_box_np,
298
- num_multimask_outputs=1,
299
- )
300
- annotations = masks
 
 
 
 
 
 
 
 
301
 
302
  seg = fast_process(
303
  annotations=annotations,
@@ -313,44 +325,10 @@ def segment_with_box(
313
  return seg
314
  return image
315
 
316
-
317
- def segment_everything(
318
- image,
319
- input_size=1024,
320
- better_quality=False,
321
- withContours=True,
322
- use_retina=True,
323
- mask_random_color=True,
324
- ):
325
- nd_image = np.array(image)
326
- masks = mask_generator.generate(nd_image)
327
- annotations = masks
328
- seg = fast_process(
329
- annotations=annotations,
330
- image=image,
331
- device=device,
332
- scale=(1024 // input_size),
333
- better_quality=better_quality,
334
- mask_random_color=mask_random_color,
335
- bbox=None,
336
- use_retina=use_retina,
337
- withContours=withContours,
338
- )
339
-
340
- return seg
341
-
342
-
343
  img_p = gr.Image(label="Input with points", type="pil")
344
  img_b = gr.Image(label="Input with box", type="pil")
345
- img_e = gr.Image(label="Input (everything)", type="pil")
346
-
347
- if args.enable_everything_mode:
348
- all_outputs = [img_p, img_b, img_e]
349
- else:
350
- all_outputs = [img_p, img_b]
351
 
352
  with gr.Blocks(css=css, title="EdgeSAM") as demo:
353
-
354
  with gr.Row():
355
  with gr.Column(scale=1):
356
  # Title
@@ -410,53 +388,24 @@ with gr.Blocks(css=css, title="EdgeSAM") as demo:
410
  run_on_click=True
411
  )
412
 
413
- if args.enable_everything_mode:
414
- with gr.Tab("Everything mode") as tab_e:
415
- # Images
416
- with gr.Row(variant="panel"):
417
- with gr.Column(scale=1):
418
- img_e.render()
419
- with gr.Column(scale=1):
420
- with gr.Row():
421
- with gr.Column():
422
- segment_btn_e = gr.Button("Start", variant="primary")
423
- reset_btn_e = gr.Button("Reset", variant="secondary")
424
- gr.Markdown(description_e)
425
-
426
- # Submit & Clear
427
- with gr.Row():
428
- with gr.Column():
429
- gr.Markdown("Try some of the examples below ⬇️")
430
- gr.Examples(
431
- examples=examples,
432
- inputs=[img_e],
433
- examples_per_page=8,
434
- )
435
-
436
  with gr.Row():
437
  with gr.Column(scale=1):
438
- gr.Markdown("<center><img src='https://visitor-badge.laobi.icu/badge?page_id=chongzhou/edgesam' alt='visitors'></center>")
 
439
 
440
  img_p.upload(on_image_upload, img_p, [img_p])
441
  img_p.select(segment_with_points, [add_or_remove], img_p)
442
 
443
  clear_btn_p.click(clear, outputs=[img_p])
444
  reset_btn_p.click(reset, outputs=[img_p])
445
- tab_p.select(fn=reset_all, outputs=all_outputs)
446
 
447
  img_b.upload(on_image_upload, img_b, [img_b])
448
  img_b.select(segment_with_box, outputs=[img_b])
449
 
450
  clear_btn_b.click(clear, outputs=[img_b])
451
  reset_btn_b.click(reset, outputs=[img_b])
452
- tab_b.select(fn=reset_all, outputs=all_outputs)
453
-
454
- if args.enable_everything_mode:
455
- segment_btn_e.click(
456
- segment_everything, inputs=[img_e], outputs=img_e
457
- )
458
- reset_btn_e.click(reset, outputs=[img_e])
459
- tab_e.select(fn=reset_all, outputs=all_outputs)
460
 
461
  demo.queue()
462
  # demo.launch(server_name=args.server_name, server_port=args.port)
 
1
  # Code credit: [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM).
2
 
3
+ import torch
4
  import gradio as gr
5
  import numpy as np
6
+ from segment_anything import sam_model_registry, SamPredictor
7
+ from segment_anything.onnx import SamPredictorONNX
8
  from PIL import ImageDraw
9
  from utils.tools_gradio import fast_process
10
  import copy
11
  import argparse
12
 
13
+ # Use ONNX to speed up the inference.
14
+ ENABLE_ONNX = True
15
+
16
  parser = argparse.ArgumentParser(
17
  description="Host EdgeSAM as a local web service."
18
  )
 
20
  "--checkpoint",
21
  default="weights/edge_sam_3x.pth",
22
  type=str,
23
+ help="The path to the PyTorch checkpoint of EdgeSAM."
24
+ )
25
+ parser.add_argument(
26
+ "--encoder-onnx-path",
27
+ default="weights/edge_sam_3x_encoder.onnx",
28
+ type=str,
29
+ help="The path to the ONNX model of EdgeSAM's encoder."
30
  )
31
  parser.add_argument(
32
+ "--decoder-onnx-path",
33
+ default="weights/edge_sam_3x_decoder.onnx",
34
+ type=str,
35
+ help="The path to the ONNX model of EdgeSAM's decoder."
36
  )
37
  parser.add_argument(
38
  "--server-name",
 
49
  args = parser.parse_args()
50
 
51
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52
+ if ENABLE_ONNX:
53
+ predictor = SamPredictorONNX(args.encoder_onnx_path, args.decoder_onnx_path)
54
+ else:
55
+ sam = sam_model_registry["edge_sam"](checkpoint=args.checkpoint, upsample_mode="bicubic")
56
+ sam = sam.to(device=device)
57
+ sam.eval()
58
+ predictor = SamPredictor(sam)
59
 
60
+ examples = [
61
+ ["assets/1.jpeg"],
62
+ ["assets/2.jpeg"],
63
+ ["assets/3.jpeg"],
64
+ ["assets/4.jpeg"],
65
+ ["assets/5.jpeg"],
66
+ ["assets/6.jpeg"],
67
+ ["assets/7.jpeg"],
68
+ ["assets/8.jpeg"],
69
+ ["assets/9.jpeg"],
70
+ ["assets/10.jpeg"],
71
+ ["assets/11.jpeg"],
72
+ ["assets/12.jpeg"],
73
+ ["assets/13.jpeg"],
74
+ ["assets/14.jpeg"],
75
+ ["assets/15.jpeg"],
76
+ ["assets/16.jpeg"]
77
+ ]
78
 
79
  # Description
80
  title = "<center><strong><font size='8'>EdgeSAM<font></strong> <a href='https://github.com/chongzhou96/EdgeSAM'><font size='6'>[GitHub]</font></a> </center>"
 
98
 
99
  """
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
102
 
103
  global_points = []
 
120
  global_image_with_prompt = None
121
  return None
122
 
123
+
124
  def reset_all():
125
  global global_points
126
  global global_point_label
 
132
  global_box = []
133
  global_image = None
134
  global_image_with_prompt = None
135
+ return None, None
 
 
 
136
 
137
 
138
  def clear():
 
184
  xyxy[1][1] = max_y
185
  return xyxy
186
 
187
+
188
  def segment_with_points(
189
+ label,
190
+ evt: gr.SelectData,
191
+ input_size=1024,
192
+ better_quality=False,
193
+ withContours=True,
194
+ use_retina=True,
195
+ mask_random_color=False,
196
  ):
197
  global global_points
198
  global global_point_label
 
213
  )
214
  image = global_image_with_prompt
215
 
216
+ if ENABLE_ONNX:
217
+ global_points_np = np.array(global_points)[None]
218
+ global_point_label_np = np.array(global_point_label)[None]
219
+ masks, scores, _ = predictor.predict(
220
+ point_coords=global_points_np,
221
+ point_labels=global_point_label_np,
222
+ )
223
+ masks = masks.squeeze(0)
224
+ scores = scores.squeeze(0)
225
+ else:
226
+ global_points_np = np.array(global_points)
227
+ global_point_label_np = np.array(global_point_label)
228
+ masks, scores, logits = predictor.predict(
229
+ point_coords=global_points_np,
230
+ point_labels=global_point_label_np,
231
+ num_multimask_outputs=4,
232
+ use_stability_score=True
233
+ )
234
 
235
  print(f'scores: {scores}')
236
  area = masks.sum(axis=(1, 2))
237
  print(f'area: {area}')
238
 
239
+ annotations = np.expand_dims(masks[scores.argmax()], axis=0)
 
 
 
240
 
241
  seg = fast_process(
242
  annotations=annotations,
 
254
 
255
 
256
  def segment_with_box(
257
+ evt: gr.SelectData,
258
+ input_size=1024,
259
+ better_quality=False,
260
+ withContours=True,
261
+ use_retina=True,
262
+ mask_random_color=False,
263
  ):
264
  global global_box
265
  global global_image
 
296
  )
297
 
298
  global_box_np = np.array(global_box)
299
+ if ENABLE_ONNX:
300
+ point_coords = global_box_np.reshape(2, 2)[None]
301
+ point_labels = np.array([2, 3])[None]
302
+ masks, _, _ = predictor.predict(
303
+ point_coords=point_coords,
304
+ point_labels=point_labels,
305
+ )
306
+ annotations = masks[:, 0, :, :]
307
+ else:
308
+ masks, scores, _ = predictor.predict(
309
+ box=global_box_np,
310
+ num_multimask_outputs=1,
311
+ )
312
+ annotations = masks
313
 
314
  seg = fast_process(
315
  annotations=annotations,
 
325
  return seg
326
  return image
327
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
328
  img_p = gr.Image(label="Input with points", type="pil")
329
  img_b = gr.Image(label="Input with box", type="pil")
 
 
 
 
 
 
330
 
331
  with gr.Blocks(css=css, title="EdgeSAM") as demo:
 
332
  with gr.Row():
333
  with gr.Column(scale=1):
334
  # Title
 
388
  run_on_click=True
389
  )
390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  with gr.Row():
392
  with gr.Column(scale=1):
393
+ gr.Markdown(
394
+ "<center><img src='https://visitor-badge.laobi.icu/badge?page_id=chongzhou/edgesam' alt='visitors'></center>")
395
 
396
  img_p.upload(on_image_upload, img_p, [img_p])
397
  img_p.select(segment_with_points, [add_or_remove], img_p)
398
 
399
  clear_btn_p.click(clear, outputs=[img_p])
400
  reset_btn_p.click(reset, outputs=[img_p])
401
+ tab_p.select(fn=reset_all, outputs=[img_p, img_b])
402
 
403
  img_b.upload(on_image_upload, img_b, [img_b])
404
  img_b.select(segment_with_box, outputs=[img_b])
405
 
406
  clear_btn_b.click(clear, outputs=[img_b])
407
  reset_btn_b.click(reset, outputs=[img_b])
408
+ tab_b.select(fn=reset_all, outputs=[img_p, img_b])
 
 
 
 
 
 
 
409
 
410
  demo.queue()
411
  # demo.launch(server_name=args.server_name, server_port=args.port)
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  torch
2
  torchvision
3
  opencv-python
4
- timm
 
 
1
  torch
2
  torchvision
3
  opencv-python
4
+ timm
5
+ onnxruntime
segment_anything/onnx/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .predictor_onnx import SamPredictorONNX
segment_anything/onnx/predictor_onnx.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import cv2
9
+
10
+ import onnxruntime
11
+ from typing import Optional, Tuple
12
+
13
+ from ..utils.transforms import ResizeLongestSide
14
+
15
+
16
+ class SamPredictorONNX:
17
+ mask_threshold: float = 0.0
18
+ image_format: str = "RGB"
19
+ img_size = 1024
20
+ pixel_mean = np.array([123.675, 116.28, 103.53])[None, :, None, None]
21
+ pixel_std = np.array([58.395, 57.12, 57.375])[None, :, None, None]
22
+
23
+ def __init__(
24
+ self,
25
+ encoder_path: str,
26
+ decoder_path: str
27
+ ) -> None:
28
+ super().__init__()
29
+ self.encoder = onnxruntime.InferenceSession(encoder_path)
30
+ self.decoder = onnxruntime.InferenceSession(decoder_path)
31
+
32
+ # Set the execution provider to GPU if available
33
+ if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
34
+ self.encoder.set_providers(['CUDAExecutionProvider'])
35
+ self.decoder.set_providers(['CUDAExecutionProvider'])
36
+
37
+ self.transform = ResizeLongestSide(self.img_size)
38
+ self.reset_image()
39
+
40
+ def set_image(
41
+ self,
42
+ image: np.ndarray,
43
+ image_format: str = "RGB",
44
+ ) -> None:
45
+ assert image_format in [
46
+ "RGB",
47
+ "BGR",
48
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
49
+ if image_format != self.image_format:
50
+ image = image[..., ::-1]
51
+
52
+ # Transform the image to the form expected by the model
53
+ input_image = self.transform.apply_image(image)
54
+ input_image = input_image.transpose(2, 0, 1)[None, :, :, :]
55
+ self.reset_image()
56
+ self.original_size = image.shape[:2]
57
+ self.input_size = tuple(input_image.shape[-2:])
58
+ input_image = self.preprocess(input_image).astype(np.float32)
59
+ outputs = self.encoder.run(None, {'image': input_image})
60
+ self.features = outputs[0]
61
+ self.is_image_set = True
62
+
63
+ def predict(
64
+ self,
65
+ point_coords: Optional[np.ndarray] = None,
66
+ point_labels: Optional[np.ndarray] = None,
67
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
68
+ if not self.is_image_set:
69
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
70
+
71
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
72
+ outputs = self.decoder.run(None, {
73
+ 'image_embeddings': self.features,
74
+ 'point_coords': point_coords.astype(np.float32),
75
+ 'point_labels': point_labels.astype(np.float32)
76
+ })
77
+ scores, low_res_masks = outputs[0], outputs[1]
78
+ masks = self.postprocess_masks(low_res_masks)
79
+ masks = masks > self.mask_threshold
80
+
81
+ return masks, scores, low_res_masks
82
+
83
+ def reset_image(self) -> None:
84
+ """Resets the currently set image."""
85
+ self.is_image_set = False
86
+ self.features = None
87
+ self.orig_h = None
88
+ self.orig_w = None
89
+ self.input_h = None
90
+ self.input_w = None
91
+
92
+ def preprocess(self, x: np.ndarray):
93
+ x = (x - self.pixel_mean) / self.pixel_std
94
+ h, w = x.shape[-2:]
95
+ padh = self.img_size - h
96
+ padw = self.img_size - w
97
+ x = np.pad(x, ((0, 0), (0, 0), (0, padh), (0, padw)), mode='constant', constant_values=0)
98
+ return x
99
+
100
+ def postprocess_masks(self, mask: np.ndarray):
101
+ mask = mask.squeeze(0).transpose(1, 2, 0)
102
+ mask = cv2.resize(mask, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR)
103
+ mask = mask[:self.input_size[0], :self.input_size[1], :]
104
+ mask = cv2.resize(mask, (self.original_size[1], self.original_size[0]), interpolation=cv2.INTER_LINEAR)
105
+ mask = mask.transpose(2, 0, 1)[None, :, :, :]
106
+ return mask