SkalskiP committed
Commit ec0b3c1
1 Parent(s): 44b03d2

initial version of SAM2 space

Files changed (2)
  1. app.py +46 -11
  2. utils/models.py +15 -1
app.py CHANGED
@@ -6,27 +6,48 @@ import supervision as sv
 import torch
 from PIL import Image
 from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
-from sam2.build_sam import build_sam2
+
+from utils.models import load_models, CHECKPOINT_NAMES
 
 MARKDOWN = """
 # Segment Anything Model 2 🔥
+<div>
+    <a href="https://github.com/facebookresearch/segment-anything-2">
+        <img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block;">
+    </a>
+    <a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">
+        <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">
+    </a>
+    <a href="https://blog.roboflow.com/what-is-segment-anything-2/">
+        <img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">
+    </a>
+    <a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">
+        <img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">
+    </a>
+</div>
 
 Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable
 visual segmentation in both images and videos. The model extends its functionality to
 video by treating images as single-frame videos. Its design, a simple transformer
-architecture with streaming memory, enables real-time video processing.
+architecture with streaming memory, enables real-time video processing. A
+model-in-the-loop data engine, which enhances the model and data through user
+interaction, was built to collect the SA-V dataset, the largest video segmentation
+dataset to date. SAM 2, trained on this extensive dataset, delivers robust performance
+across diverse tasks and visual domains.
 """
+EXAMPLES = [
+    ["tiny", "https://media.roboflow.com/notebooks/examples/dog-2.jpeg"],
+    ["small", "https://media.roboflow.com/notebooks/examples/dog-3.jpeg"],
+    ["large", "https://media.roboflow.com/notebooks/examples/dog-3.jpeg"],
+]
 
-# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-DEVICE = torch.device('cuda')
-
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
-CHECKPOINT = "checkpoints/sam2_hiera_large.pt"
-CONFIG = "sam2_hiera_l.yaml"
+MODELS = load_models(device=DEVICE)
 
-sam2_model = build_sam2(CONFIG, CHECKPOINT, device=DEVICE, apply_postprocessing=False)
 
-def process(image_input) -> Optional[Image.Image]:
+def process(checkpoint_dropdown, image_input) -> Optional[Image.Image]:
+    sam2_model = MODELS[checkpoint_dropdown]
     mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
     image = np.array(image_input.convert("RGB"))
     sam_result = mask_generator.generate(image)
@@ -36,17 +57,31 @@ def process(image_input) -> Optional[Image.Image]:
 
 with gr.Blocks() as demo:
     gr.Markdown(MARKDOWN)
-
+    with gr.Row():
+        checkpoint_dropdown_component = gr.Dropdown(
+            choices=CHECKPOINT_NAMES,
+            value=CHECKPOINT_NAMES[0],
+            label="Checkpoint", info="Select a SAM2 checkpoint to use.",
+            interactive=True
+        )
     with gr.Row():
         with gr.Column():
            image_input_component = gr.Image(type='pil', label='Upload image')
            submit_button_component = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')
+    with gr.Row():
+        gr.Examples(
+            fn=process,
+            examples=EXAMPLES,
+            inputs=[checkpoint_dropdown_component, image_input_component],
+            outputs=[image_output_component],
+            run_on_click=True
+        )
 
     submit_button_component.click(
         fn=process,
-        inputs=[image_input_component],
+        inputs=[checkpoint_dropdown_component, image_input_component],
         outputs=[image_output_component]
     )
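Note (not part of the commit): the hunk above elides the middle of process (old lines 33-35 / new lines 54-56), where the generator output is turned into the annotated image that the function returns. A minimal sketch of that step, assuming supervision's Detections.from_sam helper and annotation onto a copy of the PIL input; the exact body in the repo may differ:

    # assumed continuation of process(): convert SAM mask dicts to Detections and overlay them
    detections = sv.Detections.from_sam(sam_result=sam_result)
    annotated_image = image_input.copy()
    return MASK_ANNOTATOR.annotate(scene=annotated_image, detections=detections)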
 
utils/models.py CHANGED
@@ -1,6 +1,20 @@
+import torch
+
+from typing import Dict, Any
+from sam2.build_sam import build_sam2
+
+CHECKPOINT_NAMES = ["tiny", "small", "base_plus", "large"]
+
 CHECKPOINTS = {
     "tiny": ["sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"],
     "small": ["sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"],
     "base_plus": ["sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt"],
     "large": ["sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"],
-}
+}
+
+
+def load_models(device: torch.device) -> Dict[str, Any]:
+    models = {}
+    for key, (config, checkpoint) in CHECKPOINTS.items():
+        models[key] = build_sam2(config, checkpoint, device=device, apply_postprocessing=False)
+    return models
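Note (not part of the commit): a minimal standalone sketch of how the new helper is consumed, mirroring the wiring in app.py above. The script framing is an assumption; load_models, CHECKPOINT_NAMES, and SAM2AutomaticMaskGenerator come from the files in this diff. Because load_models builds every checkpoint eagerly, all four SAM2 variants stay resident on the selected device for the lifetime of the process.

    # assumed standalone usage of utils/models.py
    import torch
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
    from utils.models import load_models, CHECKPOINT_NAMES

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models = load_models(device=device)  # builds all four checkpoints up front
    mask_generator = SAM2AutomaticMaskGenerator(models[CHECKPOINT_NAMES[0]])  # "tiny" variant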