SkalskiP committed
Commit
af5888a
1 Parent(s): aa009f7

prompting with boxes added

Files changed (3)
  1. app.py +72 -47
  2. requirements.txt +1 -1
  3. utils/models.py +18 -7
app.py CHANGED
@@ -5,9 +5,10 @@ import numpy as np
 import supervision as sv
 import torch
 from PIL import Image
-from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
+from gradio_image_prompter import ImagePrompter

-from utils.models import load_models, CHECKPOINT_NAMES
+from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \
+    MASK_GENERATION_MODE, BOX_PROMPT_MODE

 MARKDOWN = """
 # Segment Anything Model 2 🔥
@@ -27,35 +28,50 @@ MARKDOWN = """
 </div>

 Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable
-visual segmentation in both images and videos. The model extends its functionality to
-video by treating images as single-frame videos. Its design, a simple transformer
-architecture with streaming memory, enables real-time video processing. A
-model-in-the-loop data engine, which enhances the model and data through user
-interaction, was built to collect the SA-V dataset, the largest video segmentation
-dataset to date. SAM 2, trained on this extensive dataset, delivers robust performance
-across diverse tasks and visual domains.
+visual segmentation in both images and videos. **Video segmentation will be available
+soon.**
 """
-EXAMPLES = [
-    ["tiny", "https://media.roboflow.com/notebooks/examples/dog-2.jpeg", 16],
-    ["small", "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", 16],
-    ["large", "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", 16],
-    ["large", "https://media.roboflow.com/notebooks/examples/dog-3.jpeg", 64],
-]

 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
-MODELS = load_models(device=DEVICE)
+IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)


-def process(checkpoint_dropdown, image_input, points_per_side) -> Optional[Image.Image]:
-    model = MODELS[checkpoint_dropdown]
-    mask_generator = SAM2AutomaticMaskGenerator(
-        model=model,
-        points_per_side=points_per_side)
-    image = np.array(image_input.convert("RGB"))
-    sam_result = mask_generator.generate(image)
-    detections = sv.Detections.from_sam(sam_result=sam_result)
-    return MASK_ANNOTATOR.annotate(scene=image_input, detections=detections)
+def process(
+    checkpoint_dropdown,
+    mode_dropdown,
+    image_input,
+    image_prompter_input
+) -> Optional[Image.Image]:
+    if mode_dropdown == BOX_PROMPT_MODE:
+        image_input = image_prompter_input["image"]
+        prompt = image_prompter_input["points"]
+        if len(prompt) == 0:
+            return image_input
+
+        model = IMAGE_PREDICTORS[checkpoint_dropdown]
+        image = np.array(image_input.convert("RGB"))
+        box = np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in prompt])
+
+        model.set_image(image)
+        masks, _, _ = model.predict(box=box, multimask_output=False)
+
+        # dirty fix; remove this later
+        if len(masks.shape) == 4:
+            masks = np.squeeze(masks)
+
+        detections = sv.Detections(
+            xyxy=sv.mask_to_xyxy(masks=masks),
+            mask=masks.astype(bool)
+        )
+        return MASK_ANNOTATOR.annotate(image_input, detections)
+
+    if mode_dropdown == MASK_GENERATION_MODE:
+        model = MASK_GENERATORS[checkpoint_dropdown]
+        image = np.array(image_input.convert("RGB"))
+        result = model.generate(image)
+        detections = sv.Detections.from_sam(result)
+        return MASK_ANNOTATOR.annotate(image_input, detections)


 with gr.Blocks() as demo:
@@ -67,39 +83,48 @@ with gr.Blocks() as demo:
         label="Checkpoint", info="Select a SAM2 checkpoint to use.",
         interactive=True
     )
-    points_per_side_component = gr.Slider(
-        minimum=16,
-        maximum=64,
-        value=16,
-        step=16,
-        label="Points per side",
-        info="the number of points to be sampled along one side of the image."
+    mode_dropdown_component = gr.Dropdown(
+        choices=MODE_NAMES,
+        value=MODE_NAMES[0],
+        label="Mode",
+        info="Select a mode to use. `box prompt` if you want to generate masks for "
+             "selected objects, `mask generation` if you want to generate masks "
+             "for the whole image.",
+        interactive=True
     )
     with gr.Row():
         with gr.Column():
-            image_input_component = gr.Image(type='pil', label='Upload image')
-            submit_button_component = gr.Button(value='Submit', variant='primary')
+            image_input_component = gr.Image(
+                type='pil', label='Upload image', visible=False)
+            image_prompter_input_component = ImagePrompter(
+                type='pil', label='Image prompt')
+            submit_button_component = gr.Button(
+                value='Submit', variant='primary')
         with gr.Column():
             image_output_component = gr.Image(type='pil', label='Image Output')
-    with gr.Row():
-        gr.Examples(
-            fn=process,
-            examples=EXAMPLES,
-            inputs=[
-                checkpoint_dropdown_component,
-                image_input_component,
-                points_per_side_component
-            ],
-            outputs=[image_output_component],
-            run_on_click=True
-        )

+
+    def on_mode_dropdown_change(text):
+        return [
+            gr.Image(visible=text == MASK_GENERATION_MODE),
+            ImagePrompter(visible=text == BOX_PROMPT_MODE)
+        ]
+
+    mode_dropdown_component.change(
+        on_mode_dropdown_change,
+        inputs=[mode_dropdown_component],
+        outputs=[
+            image_input_component,
+            image_prompter_input_component
+        ]
+    )
     submit_button_component.click(
         fn=process,
         inputs=[
             checkpoint_dropdown_component,
+            mode_dropdown_component,
             image_input_component,
+            image_prompter_input_component,
         ],
         outputs=[image_output_component]
     )
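Note on the new box-prompt path: the ImagePrompter value is a dict with "image" and "points" keys, and each points row appears to pack two box corners plus per-corner click labels, i.e. [x1, y1, label, x2, y2, label] (inferred from the unpacking in process). A minimal sketch of the conversion step, with that row layout as an assumption:

import numpy as np

def prompt_to_xyxy(points):
    # Keep only the coordinates from each [x1, y1, label, x2, y2, label]
    # row, yielding an (N, 4) array of xyxy boxes for SAM2ImagePredictor.
    return np.array([[x1, y1, x2, y2] for x1, y1, _, x2, y2, _ in points])

print(prompt_to_xyxy([[10, 20, 2, 110, 220, 3]]))  # [[ 10  20 110 220]]

The squeeze behind the "dirty fix" comment compensates for predict() returning masks shaped (num_boxes, 1, H, W) when several boxes are passed; dropping the singleton axis restores the (num_masks, H, W) layout sv.Detections expects.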
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 samv2
 gradio
 supervision
-gradio_image_annotation
+gradio_image_prompter
 opencv-python
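The dependency swap mirrors the UI change: gradio_image_prompter supplies the box-drawing widget that replaces the old annotation component. A quick standalone smoke test, assuming the package exposes ImagePrompter exactly as app.py uses it:

import gradio as gr
from gradio_image_prompter import ImagePrompter

def echo(value):
    # value arrives as {"image": <PIL.Image>, "points": [[x1, y1, t, x2, y2, t], ...]}
    return value["image"]

gr.Interface(fn=echo, inputs=ImagePrompter(type="pil"), outputs=gr.Image(type="pil")).launch()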
utils/models.py CHANGED
@@ -1,10 +1,16 @@
-import torch
+from typing import Dict, Tuple

-from typing import Dict, Any
+import torch
+from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
 from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor

-CHECKPOINT_NAMES = ["tiny", "small", "base_plus", "large"]
+BOX_PROMPT_MODE = "box prompt"
+MASK_GENERATION_MODE = "mask generation"
+VIDEO_SEGMENTATION_MODE = "video segmentation"
+MODE_NAMES = [BOX_PROMPT_MODE, MASK_GENERATION_MODE]

+CHECKPOINT_NAMES = ["tiny", "small", "base_plus", "large"]
 CHECKPOINTS = {
     "tiny": ["sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"],
     "small": ["sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"],
@@ -13,8 +19,13 @@ CHECKPOINTS = {
 }


-def load_models(device: torch.device) -> Dict[str, Any]:
-    models = {}
+def load_models(
+    device: torch.device
+) -> Tuple[Dict[str, SAM2ImagePredictor], Dict[str, SAM2AutomaticMaskGenerator]]:
+    image_predictors = {}
+    mask_generators = {}
     for key, (config, checkpoint) in CHECKPOINTS.items():
-        models[key] = build_sam2(config, checkpoint, device=device, apply_postprocessing=False)
-    return models
+        model = build_sam2(config, checkpoint, device=device)
+        image_predictors[key] = SAM2ImagePredictor(sam_model=model)
+        mask_generators[key] = SAM2AutomaticMaskGenerator(model=model)
+    return image_predictors, mask_generators
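load_models now hands back two parallel dicts keyed by checkpoint name, so each checkpoint is built once and serves both modes. A short usage sketch mirroring app.py (note the new build_sam2 call also drops apply_postprocessing=False, so the library default applies):

import torch
from utils.models import load_models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_predictors, mask_generators = load_models(device=device)

predictor = image_predictors["tiny"]   # SAM2ImagePredictor, for box prompts
generator = mask_generators["tiny"]    # SAM2AutomaticMaskGenerator, for whole-image masks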