|
from typing import Optional |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import supervision as sv |
|
import torch |
|
from PIL import Image |
|
from gradio_image_prompter import ImagePrompter |
|
|
|
from utils.models import load_models, CHECKPOINT_NAMES, MODE_NAMES, \ |
|
MASK_GENERATION_MODE, BOX_PROMPT_MODE |
|
|
|
# Markdown header rendered at the top of the demo via gr.Markdown: title, badge
# links (GitHub, Colab, Roboflow blog post, YouTube) and a short model blurb.
MARKDOWN = """

# Segment Anything Model 2 🔥

<div>

<a href="https://github.com/facebookresearch/segment-anything-2">

<img src="https://badges.aleen42.com/src/github.svg" alt="GitHub" style="display:inline-block;">

</a>

<a href="https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-images-with-sam-2.ipynb">

<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Colab" style="display:inline-block;">

</a>

<a href="https://blog.roboflow.com/what-is-segment-anything-2/">

<img src="https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg" alt="Roboflow" style="display:inline-block;">

</a>

<a href="https://www.youtube.com/watch?v=Dv003fTyO-Y">

<img src="https://badges.aleen42.com/src/youtube.svg" alt="YouTube" style="display:inline-block;">

</a>

</div>



Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable

visual segmentation in both images and videos. **Video segmentation will be available

soon.**

"""
|
# Clickable example rows for gr.Examples. Each row lines up with the inputs of
# `process`: (checkpoint name, mode, image for the upload component, value for
# the image prompter). All examples use the "tiny" checkpoint in
# mask-generation mode, so the prompter value is left empty.
EXAMPLES = [
    ["tiny", MASK_GENERATION_MODE, example_url, None]
    for example_url in (
        "https://media.roboflow.com/notebooks/examples/dog-2.jpeg",
        "https://media.roboflow.com/notebooks/examples/dog-3.jpeg",
        "https://media.roboflow.com/notebooks/examples/dog-4.jpeg",
    )
]
|
|
|
# Run on the GPU whenever torch can see one; otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# One shared annotator for every request; color_lookup=INDEX makes supervision
# pick each mask's colour from the detection's index.
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)

# Load every checkpoint once at import time. `process` looks both collections
# up by the checkpoint dropdown value, so they are keyed by checkpoint name.
IMAGE_PREDICTORS, MASK_GENERATORS = load_models(device=DEVICE)
|
|
|
|
|
def _prompt_to_boxes(prompt) -> np.ndarray:
    """Turn ImagePrompter prompt rows into an ``(N, 4)`` array of xyxy boxes.

    Each row is unpacked as ``(x1, y1, _, x2, y2, _)``. The corners are sorted
    so that a box dragged towards the top-left still gives
    ``x_min <= x_max`` and ``y_min <= y_max``.
    """
    # NOTE(review): ImagePrompter can presumably also emit single-point clicks
    # in the same six-value format; those are not filtered out here — confirm
    # against the gradio_image_prompter prompt format before relying on it.
    return np.array([
        [min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)]
        for x1, y1, _, x2, y2, _ in prompt
    ])


def process(
    checkpoint_dropdown,
    mode_dropdown,
    image_input,
    image_prompter_input
) -> Optional[Image.Image]:
    """Segment an image with the selected SAM 2 checkpoint and mode.

    Args:
        checkpoint_dropdown: Checkpoint name. It is used as the key into
            ``IMAGE_PREDICTORS`` (box prompt mode) or ``MASK_GENERATORS``
            (mask generation mode).
        mode_dropdown: ``BOX_PROMPT_MODE`` or ``MASK_GENERATION_MODE``.
        image_input: PIL image from the plain upload component. It is used
            only in mask generation mode and may be ``None`` if nothing was
            uploaded.
        image_prompter_input: Value of the ImagePrompter component. It is read
            as a mapping with an ``"image"`` (PIL image) entry and a
            ``"points"`` (list of prompt rows) entry. It is used only in box
            prompt mode and may be ``None``.

    Returns:
        The annotated image. In box prompt mode with no prompts, the prompter
        image is returned unchanged. ``None`` is returned when there is no
        image to work on or the mode is not recognised.
    """
    if mode_dropdown == BOX_PROMPT_MODE:
        # In this mode the image comes from the prompter, not the upload widget.
        if image_prompter_input is None:
            return None
        image_input = image_prompter_input["image"]
        prompt = image_prompter_input["points"] or []
        if image_input is None or not prompt:
            # Nothing to segment: echo back whatever image (possibly None) we got.
            return image_input

        # Annotate the same RGB image the model sees. This keeps the caller's
        # object untouched and avoids feeding non-RGB modes to the annotator.
        rgb_image = image_input.convert("RGB")
        model = IMAGE_PREDICTORS[checkpoint_dropdown]
        model.set_image(np.array(rgb_image))
        masks, _, _ = model.predict(
            box=_prompt_to_boxes(prompt), multimask_output=False)

        # A 4-D result is presumably (num_boxes, 1, H, W) because
        # multimask_output=False. Drop only that mask axis so a single box
        # still yields (1, H, W) rather than collapsing to a 2-D (H, W) array.
        if masks.ndim == 4:
            masks = masks.squeeze(axis=1)
        masks = masks.astype(bool)

        detections = sv.Detections(
            xyxy=sv.mask_to_xyxy(masks=masks),
            mask=masks
        )
        return MASK_ANNOTATOR.annotate(rgb_image, detections)

    if mode_dropdown == MASK_GENERATION_MODE:
        if image_input is None:
            return None
        rgb_image = image_input.convert("RGB")
        model = MASK_GENERATORS[checkpoint_dropdown]
        result = model.generate(np.array(rgb_image))
        detections = sv.Detections.from_sam(result)
        return MASK_ANNOTATOR.annotate(rgb_image, detections)

    # Unrecognised mode: nothing to render.
    return None
|
|
|
|
|
# Build the Gradio UI. The top row holds the checkpoint and mode selectors.
# Below it, the inputs and submit button sit on the left and the output image
# on the right, followed by clickable examples. Only one of the two input
# components is visible at a time, depending on the selected mode.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)

    # Derive the initial visibility of the input components from the
    # dropdown's default mode instead of assuming MODE_NAMES[0] is box prompt.
    initial_mode = MODE_NAMES[0]

    with gr.Row():
        checkpoint_dropdown_component = gr.Dropdown(
            choices=CHECKPOINT_NAMES,
            value=CHECKPOINT_NAMES[0],
            label="Checkpoint", info="Select a SAM2 checkpoint to use.",
            interactive=True
        )
        mode_dropdown_component = gr.Dropdown(
            choices=MODE_NAMES,
            value=initial_mode,
            label="Mode",
            info="Select a mode to use. `box prompt` if you want to generate masks for "
                 "selected objects, `mask generation` if you want to generate masks "
                 "for the whole image.",
            interactive=True
        )
    with gr.Row():
        with gr.Column():
            # Plain upload widget, used in mask generation mode.
            image_input_component = gr.Image(
                type='pil', label='Upload image',
                visible=initial_mode == MASK_GENERATION_MODE)
            # Image with drawable prompts, used in box prompt mode.
            image_prompter_input_component = ImagePrompter(
                type='pil', label='Image prompt',
                visible=initial_mode == BOX_PROMPT_MODE)
            submit_button_component = gr.Button(
                value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')

    # Shared by the examples and the submit button; order matches `process`.
    process_inputs = [
        checkpoint_dropdown_component,
        mode_dropdown_component,
        image_input_component,
        image_prompter_input_component,
    ]

    with gr.Row():
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=process_inputs,
            outputs=[image_output_component],
            run_on_click=True
        )

    def on_mode_dropdown_change(text):
        """Show only the input component that the selected mode reads from.

        Args:
            text: The newly selected mode name.

        Returns:
            Visibility updates for the upload widget and the image prompter,
            in that order.
        """
        return [
            gr.Image(visible=text == MASK_GENERATION_MODE),
            ImagePrompter(visible=text == BOX_PROMPT_MODE)
        ]

    mode_dropdown_component.change(
        on_mode_dropdown_change,
        inputs=[mode_dropdown_component],
        outputs=[
            image_input_component,
            image_prompter_input_component
        ]
    )
    submit_button_component.click(
        fn=process,
        inputs=process_inputs,
        outputs=[image_output_component]
    )

# max_threads=1 keeps requests from running concurrently against the shared,
# stateful predictors (set_image followed by predict).
demo.launch(debug=False, show_error=True, max_threads=1)
|
|