from typing import Any, Dict, List, Union

import numpy as np
from PIL import Image


def invert_mask(mask: np.ndarray) -> np.ndarray:
    """Invert mask.

    The mask is cast to ``uint8`` and bitwise-inverted, so a 0/255 mask
    becomes 255/0. Boolean masks are first scaled to 0/255: casting them
    straight to ``uint8`` gives 0/1, which bitwise-inverts to 255/254 and
    would leave every pixel non-zero.

    Args:
        mask (np.ndarray): mask

    Raises:
        ValueError: if ``mask`` is not a ``np.ndarray``.

    Returns:
        np.ndarray: inverted mask (uint8)
    """
    # isinstance() is already False for None, so no separate None check.
    if not isinstance(mask, np.ndarray):
        raise ValueError("Invalid mask")
    if mask.dtype == np.bool_:
        mask = mask.astype(np.uint8) * 255
    return np.invert(mask.astype(np.uint8))


def check_inputs_create_mask_image(
    mask: Union[np.ndarray, Image.Image],
    sam_masks: List[Dict[str, Any]],
    ignore_black_chk: bool = True,
) -> None:
    """Check create mask image inputs.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask
        sam_masks (List[Dict[str, Any]]): SAM masks
        ignore_black_chk (bool): ignore black check

    Raises:
        ValueError: if any argument has the wrong type (including None).

    Returns:
        None
    """
    # isinstance() is False for None, so None is rejected by these checks too.
    if not isinstance(mask, (np.ndarray, Image.Image)):
        raise ValueError("Invalid mask")
    if not isinstance(sam_masks, list):
        raise ValueError("Invalid SAM masks")
    if not isinstance(ignore_black_chk, bool):
        raise ValueError("Invalid ignore black check")


def convert_mask(mask: Union[np.ndarray, Image.Image]) -> np.ndarray:
    """Convert mask to a single-channel array of shape (H, W, 1).

    Args:
        mask (Union[np.ndarray, Image.Image]): mask; a PIL image, or an
            array of shape (H, W) or (H, W, C).

    Raises:
        ValueError: if the mask is neither 2-D nor 3-D.

    Returns:
        np.ndarray: converted mask of shape (H, W, 1); for multi-channel
            input only the first channel is kept.
    """
    if isinstance(mask, Image.Image):
        mask = np.array(mask)
    if mask.ndim == 2:
        mask = mask[:, :, np.newaxis]
    elif mask.ndim != 3:
        raise ValueError(f"Invalid mask dimensions: {mask.ndim}")
    if mask.shape[2] != 1:
        mask = mask[:, :, 0:1]
    return mask


def create_mask_image(
    mask: Union[np.ndarray, Image.Image],
    sam_masks: List[Dict[str, Any]],
    ignore_black_chk: bool = True,
) -> np.ndarray:
    """Create mask image.

    Walks ``sam_masks`` in order and assigns each pixel to the first segment
    that contains it. A segment's newly assigned pixels are added to the
    output region when any of them overlap the non-zero area of ``mask``.
    When ``ignore_black_chk`` is False, the pixels no segment covers are
    added as a whole if any of them overlap ``mask``.

    Args:
        mask (Union[np.ndarray, Image.Image]): mask; non-zero pixels of the
            first channel mark the selection.
        sam_masks (List[Dict[str, Any]]): SAM masks; each dict's
            ``"segmentation"`` entry must have the mask's (H, W) shape.
        ignore_black_chk (bool): ignore black check; if False, pixels not
            covered by any segment are also considered.

    Raises:
        ValueError: on invalid inputs, or when a segmentation's shape does
            not match the mask's height and width.

    Returns:
        np.ndarray: mask image of shape (H, W, 3), dtype uint8, 255 inside
            the selected region and 0 elsewhere.
    """
    check_inputs_create_mask_image(mask, sam_masks, ignore_black_chk)
    mask = convert_mask(mask)
    in_mask = mask.astype(bool)

    covered = np.zeros(mask.shape, dtype=bool)  # pixels already claimed by a segment
    region = np.zeros(mask.shape, dtype=bool)  # pixels selected for the output
    for seg_dict in sam_masks:
        segmentation = np.asarray(seg_dict["segmentation"])
        if segmentation.shape != mask.shape[:2]:
            raise ValueError(
                f"Segmentation shape {segmentation.shape} does not match "
                f"mask shape {mask.shape[:2]}"
            )
        # Cast through uint8 so fractional values below 1 count as background.
        seg_mask = np.expand_dims(segmentation.astype(np.uint8) != 0, axis=-1)
        # Only pixels not claimed by an earlier segment belong to this one.
        new_pixels = seg_mask & ~covered
        if (new_pixels & in_mask).any():
            region |= new_pixels
        covered |= seg_mask

    if not ignore_black_chk:
        uncovered = ~covered
        if (uncovered & in_mask).any():
            region |= uncovered

    return np.tile(region.astype(np.uint8) * 255, (1, 1, 3))