from functools import reduce
import os

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import wget
from PIL import Image

import modules.scripts as scripts
from modules.paths_internal import models_path
from modules.ui_components import FormRow, FormGroup

BaseOptions = mp.tasks.BaseOptions
ImageSegmenter = mp.tasks.vision.ImageSegmenter
ImageSegmenterOptions = mp.tasks.vision.ImageSegmenterOptions
VisionRunningMode = mp.tasks.vision.RunningMode

# UI labels for the selectable mask categories. Their order matches the
# channel indices of the multiclass segmentation model (see generate_mask).
MASK_OPTION_0_BACKGROUND = 'background'
MASK_OPTION_1_HAIR = 'hair'
MASK_OPTION_2_BODY = 'body (skin)'
MASK_OPTION_3_FACE = 'face (skin)'
MASK_OPTION_4_CLOTHES = 'clothes'

title = "A Person Mask Generator"


class Script(scripts.Script):
    """WebUI extension script that builds person-part masks (hair, face,
    body, clothes, background) from an image using MediaPipe's multiclass
    selfie segmentation model."""

    def __init__(self):
        self.img2img: gr.Image = None

    def title(self):
        return title

    def show(self, is_img2img):
        # Always listed in the script UI, regardless of tab.
        return scripts.AlwaysVisible

    def get_mediapipe_image(self, image: Image) -> mp.Image:
        """Convert a PIL image into a mediapipe Image.

        Picks SRGBA for 4-channel input; 3-channel input is treated as BGR
        and converted to RGB before wrapping.
        """
        numpy_image = np.asarray(image)

        image_format = mp.ImageFormat.SRGB

        # Choose pixel format from channel count.
        if numpy_image.shape[-1] == 4:
            image_format = mp.ImageFormat.SRGBA
        elif numpy_image.shape[-1] == 3:
            image_format = mp.ImageFormat.SRGB
            # NOTE(review): assumes 3-channel input arrives as BGR — confirm
            # against callers; PIL images are normally RGB already.
            numpy_image = cv2.cvtColor(numpy_image, cv2.COLOR_BGR2RGB)

        return mp.Image(image_format=image_format, data=numpy_image)

    def generate_mask(self, image: Image, mask_targets: list[str], mask_dilation: int) -> Image:
        """Segment `image` and return a black/white RGBA mask image.

        Args:
            image: source image to segment.
            mask_targets: category labels (MASK_OPTION_* constants) to
                include in the mask; their per-pixel maxima are merged.
            mask_dilation: >0 dilates the mask by that radius, <0 erodes
                by the magnitude, 0 leaves it unchanged.

        Returns:
            A PIL Image mask (white = selected categories), or None when
            there is no image or no targets.
        """
        if image is None or not mask_targets:
            return None

        # Lazily download the segmentation model into models/mediapipe.
        model_folder_name = 'mediapipe'
        model_file_name = 'selfie_multiclass_256x256.tflite'

        model_folder_path = (
            models_path
            if models_path.endswith(model_folder_name)
            else os.path.join(models_path, model_folder_name)
        )
        os.makedirs(model_folder_path, exist_ok=True)

        model_path = os.path.join(model_folder_path, model_file_name)
        model_url = 'https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite'
        if not os.path.exists(model_path):
            print(f"Downloading \n'selfie_multiclass_256x256.tflite' model")
            wget.download(model_url, model_path)

        options = ImageSegmenterOptions(
            base_options=BaseOptions(model_asset_path=model_path),
            running_mode=VisionRunningMode.IMAGE,
            output_category_mask=True,
        )

        # Create the image segmenter and retrieve the per-category masks.
        with ImageSegmenter.create_from_options(options) as segmenter:
            media_pipe_image = self.get_mediapipe_image(image=image)
            segmented_masks = segmenter.segment(media_pipe_image)

            # Model output channels, per
            # https://developers.google.com/mediapipe/solutions/vision/image_segmenter#multiclass-model
            # 0 background, 1 hair, 2 body-skin, 3 face-skin, 4 clothes,
            # 5 others (accessories). Unknown labels fall back to background.
            target_to_index = {
                MASK_OPTION_1_HAIR: 1,
                MASK_OPTION_2_BODY: 2,
                MASK_OPTION_3_FACE: 3,
                MASK_OPTION_4_CLOTHES: 4,
            }
            masks = [
                segmented_masks.confidence_masks[target_to_index.get(target, 0)]
                for target in mask_targets
            ]

            image_data = media_pipe_image.numpy_view()
            image_shape = image_data.shape

            # Widen RGB shape to RGBA so the fill colors below carry an
            # alpha channel.
            if image_shape[-1] == 3:
                image_shape = (image_shape[0], image_shape[1], 4)

            mask_background_array = np.zeros(image_shape, dtype=np.uint8)
            mask_background_array[:] = (0, 0, 0, 255)

            mask_foreground_array = np.zeros(image_shape, dtype=np.uint8)
            mask_foreground_array[:] = (255, 255, 255, 255)

            mask_arrays = []
            for mask in masks:
                # Threshold each confidence mask at 0.25, broadcast across
                # the channel axis so white/black fills apply per pixel.
                condition = (
                    np.stack((mask.numpy_view(),) * image_shape[-1], axis=-1) > 0.25
                )
                mask_arrays.append(
                    np.where(condition, mask_foreground_array, mask_background_array)
                )

            # Merge our masks taking the maximum from each.
            merged_mask_arrays = reduce(np.maximum, mask_arrays)

            # Dilate or erode the mask.
            if mask_dilation > 0:
                kernel = cv2.getStructuringElement(
                    cv2.MORPH_ELLIPSE,
                    (2 * mask_dilation + 1, 2 * mask_dilation + 1),
                    (mask_dilation, mask_dilation),
                )
                merged_mask_arrays = cv2.dilate(merged_mask_arrays, kernel)
            elif mask_dilation < 0:
                # BUG FIX: the original passed the negative value straight
                # into getStructuringElement, producing an invalid negative
                # kernel size/anchor; erosion must use the magnitude.
                radius = -mask_dilation
                kernel = cv2.getStructuringElement(
                    cv2.MORPH_ELLIPSE,
                    (2 * radius + 1, 2 * radius + 1),
                    (radius, radius),
                )
                merged_mask_arrays = cv2.erode(merged_mask_arrays, kernel)

            # Create the image.
            return Image.fromarray(merged_mask_arrays)

    # How the script's is displayed in the UI. See https://gradio.app/docs/#components
    # for the different UI components you can use and how to create them.
    # Most UI components can return a value, such as a boolean for a checkbox.
    # The returned values are passed to the run method as parameters.
    def ui(self, is_img2img):
        if is_img2img:
            with gr.Accordion(title, open=False, elem_id="a_person_mask_generator_accordion") as accordion:
                with gr.Blocks():
                    with gr.Row():
                        with gr.Column():
                            enabled = gr.Checkbox(
                                label="Enable",
                                value=False,
                                elem_id="a_person_mask_generator_enable_checkbox",
                            )
                            mask_targets = gr.Dropdown(
                                label="Mask",
                                multiselect=True,
                                choices=[
                                    MASK_OPTION_0_BACKGROUND,
                                    MASK_OPTION_1_HAIR,
                                    MASK_OPTION_2_BODY,
                                    MASK_OPTION_3_FACE,
                                    MASK_OPTION_4_CLOTHES,
                                ],
                                value=[MASK_OPTION_3_FACE],
                                elem_id="a_person_mask_generator_mask_dropdown",
                                interactive=True,
                            )
                            gr.HTML("\n")
                            preview_button = gr.Button(
                                value="Preview Mask *",
                                elem_id="person_mask_generator_preview_button",
                            )
                            # NOTE(review): the reviewed source was truncated
                            # mid-call here (`gr.HTML("`); the original HTML
                            # content, the preview_button click wiring, and
                            # the method's return statement are unknown —
                            # restore them from the upstream extension.
                            gr.HTML("")

            # NOTE(review): returning the interactive controls so their
            # values reach the processing hook — confirm the exact component
            # list against the upstream source.
            return [enabled, mask_targets]
        return []