import os
import sys

# Make this file's directory importable when the node is loaded by ComfyUI.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from functools import reduce

import torch
import numpy as np
from PIL import Image
import mediapipe as mp

import folder_paths


def get_a_person_mask_generator_model_path() -> str:
    model_folder_name = 'mediapipe'
    model_name = 'selfie_multiclass_256x256.tflite'

    model_folder_path = os.path.join(folder_paths.models_dir, model_folder_name)
    model_file_path = os.path.join(model_folder_path, model_name)

    # Download the segmentation model into the ComfyUI models folder on first use.
    if not os.path.exists(model_file_path):
        import wget
        model_url = f'https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/{model_name}'
        print(f"Downloading '{model_name}' model")
        os.makedirs(model_folder_path, exist_ok=True)
        wget.download(model_url, model_file_path)

    return model_file_path


class APersonMaskGenerator:

    def __init__(self):
        # Make sure the segmentation model is available before the node runs.
        get_a_person_mask_generator_model_path()

    @classmethod
    def INPUT_TYPES(cls):
        false_widget = ("BOOLEAN", {"default": False, "label_on": "enabled", "label_off": "disabled"})
        true_widget = ("BOOLEAN", {"default": True, "label_on": "enabled", "label_off": "disabled"})

        return {
            "required": {
                "images": ("IMAGE",),
            },
            "optional": {
                "face_mask": true_widget,
                "background_mask": false_widget,
                "hair_mask": false_widget,
                "body_mask": false_widget,
                "clothes_mask": false_widget,
                "confidence": ("FLOAT", {"default": 0.40, "min": 0.01, "max": 1.0, "step": 0.01}),
            },
        }

    CATEGORY = "A Person Mask Generator - David Bielejeski"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("masks",)

    FUNCTION = "generate_mask"

    def get_mediapipe_image(self, image: Image.Image) -> mp.Image:
        # The PIL image is already in RGB (or RGBA) channel order, so the pixel
        # data can be handed to MediaPipe directly; only the format flag has to match.
        numpy_image = np.asarray(image)

        image_format = mp.ImageFormat.SRGB
        if numpy_image.shape[-1] == 4:
            image_format = mp.ImageFormat.SRGBA

        return mp.Image(image_format=image_format, data=numpy_image)

    def generate_mask(self, images, face_mask: bool, background_mask: bool, hair_mask: bool,
                      body_mask: bool, clothes_mask: bool, confidence: float):
        """Create segmentation masks from a batch of images.

        Args:
            images (torch.Tensor): The images to create the masks for.
            face_mask (bool): create a mask for the face.
            background_mask (bool): create a mask for the background.
            hair_mask (bool): create a mask for the hair.
            body_mask (bool): create a mask for the body (skin).
            clothes_mask (bool): create a mask for the clothes.
            confidence (float): minimum confidence for a pixel to count as part of a mask.

        Returns:
            torch.Tensor: The segmentation masks.
        """
        a_person_mask_generator_model_path = get_a_person_mask_generator_model_path()
        with open(a_person_mask_generator_model_path, "rb") as f:
            a_person_mask_generator_model_buffer = f.read()

        image_segmenter_base_options = mp.tasks.BaseOptions(model_asset_buffer=a_person_mask_generator_model_buffer)
        options = mp.tasks.vision.ImageSegmenterOptions(
            base_options=image_segmenter_base_options,
            running_mode=mp.tasks.vision.RunningMode.IMAGE,
            output_category_mask=True)

        res_masks = []
        with mp.tasks.vision.ImageSegmenter.create_from_options(options) as segmenter:
            for image in images:
                # ComfyUI images are float tensors in [0, 1]; convert to an 8-bit PIL image.
                i = 255. * image.cpu().numpy()
                image_pil = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))

                media_pipe_image = self.get_mediapipe_image(image=image_pil)
                segmented_masks = segmenter.segment(media_pipe_image)

                # Confidence mask order for the selfie_multiclass model:
                # 0 background, 1 hair, 2 body (skin), 3 face, 4 clothes, 5 other.
                masks = []
                if background_mask:
                    masks.append(segmented_masks.confidence_masks[0])
                if hair_mask:
                    masks.append(segmented_masks.confidence_masks[1])
                if body_mask:
                    masks.append(segmented_masks.confidence_masks[2])
                if face_mask:
                    masks.append(segmented_masks.confidence_masks[3])
                if clothes_mask:
                    masks.append(segmented_masks.confidence_masks[4])

                image_data = media_pipe_image.numpy_view()
                image_shape = image_data.shape

                # Pad the shape to four channels so the RGBA fill values below line up.
                if image_shape[-1] == 3:
                    image_shape = (image_shape[0], image_shape[1], 4)

                mask_background_array = np.zeros(image_shape, dtype=np.uint8)
                mask_background_array[:] = (0, 0, 0, 255)

                mask_foreground_array = np.zeros(image_shape, dtype=np.uint8)
                mask_foreground_array[:] = (255, 255, 255, 255)

                mask_arrays = []
                if len(masks) == 0:
                    # No categories selected: fall back to an all-black mask.
                    mask_arrays.append(mask_background_array)
                else:
                    for mask in masks:
                        # Pixels above the confidence threshold become white, the rest black.
                        condition = np.stack((mask.numpy_view(),) * image_shape[-1], axis=-1) > confidence
                        mask_array = np.where(condition, mask_foreground_array, mask_background_array)
                        mask_arrays.append(mask_array)

                # Merge the selected category masks into one mask.
                merged_mask_arrays = reduce(np.maximum, mask_arrays)

                # Convert the merged mask to a single-channel float tensor shaped [1, H, W].
                mask_image = Image.fromarray(merged_mask_arrays)
                tensor_mask = mask_image.convert("RGB")
                tensor_mask = np.array(tensor_mask).astype(np.float32) / 255.0
                tensor_mask = torch.from_numpy(tensor_mask)[None,]
                tensor_mask = tensor_mask[..., 0]

                res_masks.append(tensor_mask)

        return (torch.cat(res_masks, dim=0),)
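

# ---------------------------------------------------------------------------
# Minimal usage sketch of APersonMaskGenerator.generate_mask(). This is an
# illustrative smoke test, not part of the node interface: it assumes the file
# is run where ComfyUI's `folder_paths` module is importable, and that a test
# image named "person.jpg" exists next to this script (both the file name and
# the direct method call are assumptions for demonstration; in normal use the
# node is driven by the ComfyUI graph).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # ComfyUI passes images as float tensors shaped [batch, height, width, channels]
    # with values in [0, 1]; build one from a local image the same way.
    test_image = Image.open("person.jpg").convert("RGB")
    image_tensor = torch.from_numpy(
        np.asarray(test_image).astype(np.float32) / 255.0
    )[None,]

    generator = APersonMaskGenerator()
    (masks,) = generator.generate_mask(
        images=image_tensor,
        face_mask=True,
        background_mask=False,
        hair_mask=False,
        body_mask=False,
        clothes_mask=False,
        confidence=0.4,
    )
    print(masks.shape)  # expected: torch.Size([1, height, width])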