import os
import sys
# Make this file's directory importable so local helper modules resolve
# when ComfyUI loads the node from custom_nodes
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from functools import reduce
import torch
import numpy as np
from PIL import Image
import mediapipe as mp
import folder_paths


def get_a_person_mask_generator_model_path() -> str:
model_folder_name = 'mediapipe'
model_name = 'selfie_multiclass_256x256.tflite'
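    # Cache the model under ComfyUI's models directory (models/mediapipe)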
model_folder_path = os.path.join(folder_paths.models_dir, model_folder_name)
model_file_path = os.path.join(model_folder_path, model_name)
if not os.path.exists(model_file_path):
import wget
model_url = f'https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_multiclass_256x256/float32/latest/{model_name}'
print(f"Downloading '{model_name}' model")
os.makedirs(model_folder_path, exist_ok=True)
wget.download(model_url, model_file_path)
return model_file_path


class APersonMaskGenerator:
def __init__(self):
# download the model if we need it
get_a_person_mask_generator_model_path()

    @classmethod
    def INPUT_TYPES(cls):
false_widget = ("BOOLEAN", {"default": False, "label_on": "enabled", "label_off": "disabled"})
true_widget = ("BOOLEAN", {"default": True, "label_on": "enabled", "label_off": "disabled"})
return {
"required":
{
"images": ("IMAGE",),
},
"optional":
{
"face_mask": true_widget,
"background_mask": false_widget,
"hair_mask": false_widget,
"body_mask": false_widget,
"clothes_mask": false_widget,
"confidence": ("FLOAT", {"default": 0.40, "min": 0.01, "max": 1.0, "step": 0.01},),
}
}
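
    # Standard ComfyUI node metadata: the menu category, the return
    # signature, and the name of the method ComfyUI calls to run the node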
CATEGORY = "A Person Mask Generator - David Bielejeski"
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("masks",)
FUNCTION = "generate_mask"

    def get_mediapipe_image(self, image: Image.Image) -> mp.Image:
        # Convert the PIL image to a NumPy array; PIL already stores channels
        # as RGB/RGBA, which matches MediaPipe's SRGB/SRGBA formats, so no
        # channel swap is needed here
        numpy_image = np.asarray(image)

        # Pick the MediaPipe image format from the channel count
        if numpy_image.shape[-1] == 4:
            image_format = mp.ImageFormat.SRGBA
        else:
            image_format = mp.ImageFormat.SRGB

        return mp.Image(image_format=image_format, data=numpy_image)

    def generate_mask(self, images, face_mask: bool, background_mask: bool, hair_mask: bool, body_mask: bool, clothes_mask: bool, confidence: float):
        """Create segmentation masks from a batch of images.

        Args:
            images (torch.Tensor): The images to create masks for.
            face_mask (bool): create a mask for the face.
            background_mask (bool): create a mask for the background.
            hair_mask (bool): create a mask for the hair.
            body_mask (bool): create a mask for the body (skin).
            clothes_mask (bool): create a mask for the clothes.
            confidence (float): per-pixel threshold a category must exceed
                to be included in its mask.

        Returns:
            torch.Tensor: The segmentation masks.
        """
        model_path = get_a_person_mask_generator_model_path()
        with open(model_path, "rb") as f:
            model_buffer = f.read()

        image_segmenter_base_options = mp.tasks.BaseOptions(model_asset_buffer=model_buffer)
options = mp.tasks.vision.ImageSegmenterOptions(
base_options=image_segmenter_base_options,
running_mode=mp.tasks.vision.RunningMode.IMAGE,
output_category_mask=True)
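        # Note: output_confidence_masks defaults to True in ImageSegmenterOptions,
        # so the per-category confidence masks consumed below are produced as well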
# Create the image segmenter
res_masks = []
with mp.tasks.vision.ImageSegmenter.create_from_options(options) as segmenter:
for image in images:
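                # ComfyUI IMAGE tensors are float32 in [0, 1] with shape [H, W, C]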
# Convert the Tensor to a PIL image
i = 255. * image.cpu().numpy()
image_pil = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
# Retrieve the masks for the segmented image
media_pipe_image = self.get_mediapipe_image(image=image_pil)
segmented_masks = segmenter.segment(media_pipe_image)
# https://developers.google.com/mediapipe/solutions/vision/image_segmenter#multiclass-model
# 0 - background
# 1 - hair
# 2 - body - skin
# 3 - face - skin
# 4 - clothes
# 5 - others(accessories)
masks = []
if background_mask:
masks.append(segmented_masks.confidence_masks[0])
if hair_mask:
masks.append(segmented_masks.confidence_masks[1])
if body_mask:
masks.append(segmented_masks.confidence_masks[2])
if face_mask:
masks.append(segmented_masks.confidence_masks[3])
if clothes_mask:
masks.append(segmented_masks.confidence_masks[4])
image_data = media_pipe_image.numpy_view()
image_shape = image_data.shape
                # The masks are built as RGBA images, so give the target shape an alpha channel
                if image_shape[-1] == 3:
                    image_shape = (image_shape[0], image_shape[1], 4)

                # Solid black (background) and white (foreground) canvases that
                # the thresholded masks select between
                mask_background_array = np.zeros(image_shape, dtype=np.uint8)
                mask_background_array[:] = (0, 0, 0, 255)
                mask_foreground_array = np.zeros(image_shape, dtype=np.uint8)
                mask_foreground_array[:] = (255, 255, 255, 255)
mask_arrays = []
if len(masks) == 0:
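                    # No categories selected: fall back to an all-black (empty) mask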
mask_arrays.append(mask_background_array)
else:
for i, mask in enumerate(masks):
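                        # Broadcast the one-channel confidence mask across the image
                        # channels, then keep the pixels above the confidence threshold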
condition = np.stack((mask.numpy_view(),) * image_shape[-1], axis=-1) > confidence
mask_array = np.where(condition, mask_foreground_array, mask_background_array)
mask_arrays.append(mask_array)
# Merge our masks taking the maximum from each
merged_mask_arrays = reduce(np.maximum, mask_arrays)
# Create the image
mask_image = Image.fromarray(merged_mask_arrays)
                # Convert the PIL mask back to a single-channel float tensor in [0, 1]
                mask_rgb = mask_image.convert("RGB")
                tensor_mask = torch.from_numpy(np.array(mask_rgb).astype(np.float32) / 255.0)[None,]
                # All three channels are identical, so keep just the first
                tensor_mask = tensor_mask[..., 0]
res_masks.append(tensor_mask)
return (torch.cat(res_masks, dim=0),)
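

# A minimal registration sketch: ComfyUI discovers custom nodes through
# NODE_CLASS_MAPPINGS. In this repo the mapping may well live in __init__.py
# instead; the display name below is illustrative, not taken from the source.
NODE_CLASS_MAPPINGS = {
    "APersonMaskGenerator": APersonMaskGenerator,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "APersonMaskGenerator": "A Person Mask Generator",
}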