import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from ultralytics import YOLO
from typing import List, Optional, Union
from dataclasses import dataclass


@dataclass
class SegmentationResult:
    """Data class to store segmentation results"""
    label: str
    confidence: float
    mask: np.ndarray
    bounding_box: List[int]


class ObjectSegmenter:
    """A class for zero-shot object detection and segmentation"""

    def __init__(self, device: Optional[str] = None):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if self.device == "cuda":
            torch.cuda.empty_cache()
        self._init_models()

    def _init_models(self):
        """Initialize Grounding DINO and YOLO models"""
        # Grounding DINO setup for zero-shot, text-prompted detection
        self.dino_processor = AutoProcessor.from_pretrained(
            "IDEA-Research/grounding-dino-tiny"
        )
        self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
            "IDEA-Research/grounding-dino-tiny"
        ).to(self.device).eval()
        # YOLO setup for instance segmentation masks
        self.yolo_model = YOLO('yolov8n-seg.pt')

    def segment_objects(
        self,
        image: Union[Image.Image, np.ndarray, str],
        objects: Union[str, List[str]],
        box_threshold: float = 0.4,
        text_threshold: float = 0.3,
    ) -> List[SegmentationResult]:
        """Segment the specified objects in the image"""
        # Prepare image: accept a file path, a NumPy array, or a PIL image
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Prepare text prompt: Grounding DINO expects lowercased queries,
        # each ending with a period
        if isinstance(objects, list):
            text_prompt = ". ".join(objects)
        else:
            text_prompt = objects
        text_prompt = text_prompt.lower()
        if not text_prompt.endswith('.'):
            text_prompt += '.'

        # Get DINO detections
        dino_results = self._get_dino_detections(
            image, text_prompt, box_threshold, text_threshold
        )

        # Get YOLO segmentation; retina_masks=True returns masks at the
        # original image resolution so they are comparable with the DINO boxes
        yolo_results = self.yolo_model(image, verbose=False, retina_masks=True)[0]

        # Match detections with segmentations
        return self._process_results(dino_results, yolo_results)

    @torch.no_grad()
    def _get_dino_detections(
        self,
        image: Image.Image,
        text_prompt: str,
        box_threshold: float,
        text_threshold: float,
    ) -> dict:
        """Get object detections from Grounding DINO"""
        inputs = self.dino_processor(
            images=image,
            text=text_prompt,
            return_tensors="pt"
        ).to(self.device)
        outputs = self.dino_model(**inputs)
        # target_sizes expects (height, width); PIL's size is (width, height)
        results = self.dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            target_sizes=[image.size[::-1]]
        )[0]
        return results

    def _process_results(
        self,
        dino_results: dict,
        yolo_results
    ) -> List[SegmentationResult]:
        """Match detections with segmentations and create result objects"""
        segmentation_results = []
        for box, score, label in zip(
            dino_results["boxes"],
            dino_results["scores"],
            dino_results["labels"]
        ):
            box = [int(x) for x in box.tolist()]
            # Find the best-matching YOLO mask for this detection
            best_mask = self._find_best_mask(box, yolo_results)
            if best_mask is not None:
                result = SegmentationResult(
                    label=label,
                    confidence=float(score),
                    mask=best_mask,
                    bounding_box=box
                )
                segmentation_results.append(result)
        return segmentation_results

    def _find_best_mask(self, box: List[int], yolo_results) -> Optional[np.ndarray]:
        """Find the best-matching YOLO mask for a given bounding box"""
        # masks is None when YOLO detects no instances at all
        if yolo_results.masks is None or len(yolo_results.masks.data) == 0:
            return None
        best_iou = 0.0
        best_mask = None
        for mask in yolo_results.masks.data:
            mask_np = mask.cpu().numpy()
            y_indices, x_indices = np.where(mask_np > 0)
            if len(y_indices) == 0:
                continue
            # Tight bounding box of the mask in image coordinates
            mask_box = [
                x_indices.min(),
                y_indices.min(),
                x_indices.max(),
                y_indices.max()
            ]
            iou = self._calculate_iou(box, mask_box)
            if iou > best_iou:
                best_iou = iou
                best_mask = mask_np
        return best_mask

    @staticmethod
    def _calculate_iou(box1: List[int], box2: List[int]) -> float:
        """Calculate Intersection over Union between two boxes"""
        intersection = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0])) * \
            max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = box1_area + box2_area - intersection
        return intersection / union if union > 0 else 0.0


# Initialize the segmenter
segmenter = ObjectSegmenter()
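

# Example usage: a minimal sketch of calling segment_objects() and reading the
# SegmentationResult fields defined above. 'street.jpg' is a hypothetical local
# image path and the object names are illustrative; substitute your own.
if __name__ == "__main__":
    results = segmenter.segment_objects(
        image='street.jpg',           # hypothetical path; any RGB image works
        objects=["car", "person"],    # free-form text queries for Grounding DINO
    )
    for r in results:
        print(f"{r.label}: confidence={r.confidence:.2f}, "
              f"box={r.bounding_box}, mask pixels={int(r.mask.sum())}")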