"""Gradio demo: run SAM 2 automatic mask generation on an uploaded image."""

from typing import Optional

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

from utils.models import load_models, CHECKPOINT_NAMES

# Header text rendered at the top of the page. The literal is kept verbatim
# (including its leading/trailing spaces) because it is user-facing output.
MARKDOWN = """ # Segment Anything Model 2 🔥
Segment Anything Model 2 (SAM 2) is a foundation model designed to address promptable visual segmentation in both images and videos. The model extends its functionality to video by treating images as single-frame videos. Its design, a simple transformer architecture with streaming memory, enables real-time video processing. A model-in-the-loop data engine, which enhances the model and data through user interaction, was built to collect the SA-V dataset, the largest video segmentation dataset to date. SAM 2, trained on this extensive dataset, delivers robust performance across diverse tasks and visual domains. """

# Each row is [checkpoint name, image URL], matching the order of the
# `inputs` list passed to gr.Examples below.
# NOTE(review): the checkpoint names are presumably members of
# CHECKPOINT_NAMES / keys of MODELS - confirm against utils.models.
EXAMPLES = [
    ["tiny", "https://media.roboflow.com/notebooks/examples/dog-2.jpeg"],
    ["small", "https://media.roboflow.com/notebooks/examples/dog-3.jpeg"],
    ["large", "https://media.roboflow.com/notebooks/examples/dog-3.jpeg"],
]

# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MASK_ANNOTATOR = sv.MaskAnnotator(color_lookup=sv.ColorLookup.INDEX)
# Models are loaded once at import time; `process` looks them up by
# checkpoint name on every request.
MODELS = load_models(device=DEVICE)


def process(checkpoint_dropdown, image_input) -> Optional[Image.Image]:
    """Overlay automatically generated SAM 2 masks on the given image.

    Args:
        checkpoint_dropdown: Checkpoint name selected in the UI; used as the
            key into ``MODELS``.
        image_input: The uploaded image as a PIL image (the input component
            is created with ``type='pil'``), or ``None`` when no image has
            been provided.

    Returns:
        The result of annotating ``image_input`` with the generated masks,
        or ``None`` when ``image_input`` is ``None``.

    Raises:
        KeyError: If ``checkpoint_dropdown`` is not a key of ``MODELS``
            (assuming ``MODELS`` is a plain mapping).
    """
    # Submit can be clicked before an image is uploaded; without this guard
    # the `.convert` call below fails with AttributeError on None.
    if image_input is None:
        return None

    sam2_model = MODELS[checkpoint_dropdown]
    mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
    # The generator is fed a NumPy array built from the RGB-converted image.
    image = np.array(image_input.convert("RGB"))
    sam_result = mask_generator.generate(image)
    detections = sv.Detections.from_sam(sam_result=sam_result)
    return MASK_ANNOTATOR.annotate(scene=image_input, detections=detections)


# Page layout: header, checkpoint selector, input/output columns, examples.
with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        checkpoint_dropdown_component = gr.Dropdown(
            choices=CHECKPOINT_NAMES,
            value=CHECKPOINT_NAMES[0],
            label="Checkpoint",
            info="Select a SAM2 checkpoint to use.",
            interactive=True
        )
    with gr.Row():
        with gr.Column():
            image_input_component = gr.Image(type='pil', label='Upload image')
            submit_button_component = gr.Button(value='Submit', variant='primary')
        with gr.Column():
            image_output_component = gr.Image(type='pil', label='Image Output')
    with gr.Row():
        # run_on_click=True: selecting an example immediately runs `process`.
        gr.Examples(
            fn=process,
            examples=EXAMPLES,
            inputs=[checkpoint_dropdown_component, image_input_component],
            outputs=[image_output_component],
            run_on_click=True
        )

    submit_button_component.click(
        fn=process,
        inputs=[checkpoint_dropdown_component, image_input_component],
        outputs=[image_output_component]
    )

# NOTE(review): max_threads=1 presumably keeps concurrent requests from
# competing for the same model/device - confirm the intent.
demo.launch(debug=False, show_error=True, max_threads=1)