import logging
import time

import cv2
import numpy as np

from .center_crop import center_crop
from .face_detector import FaceDetector


class VSNetModelPipeline:
    """Runs the model separately on detected face crops and on the full frame,
    then blends the stylized faces back into the stylized background."""

    def __init__(self, model, face_detector: FaceDetector, background_resize=720, no_detected_resize=256):
        self.background_resize = background_resize
        self.no_detected_resize = no_detected_resize
        self.model = model
        self.face_detector = face_detector
        # Soft circular mask used to feather processed face crops back into the frame.
        self.mask = self.create_circular_mask(face_detector.target_size, face_detector.target_size, power=1 / 4)

    @staticmethod
    def create_circular_mask(h, w, center=None, power=None):
        """Build an (h, w, 3) float mask that is 1.0 at the center and falls off
        radially towards 0.0 at the borders; `power` < 1 widens the bright region."""
        if center is None:  # use the middle of the image
            center = (int(w / 2), int(h / 2))
        Y, X = np.ogrid[:h, :w]
        dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
        dist_from_center = np.clip(dist_from_center, a_min=0, a_max=max(h / 2, w / 2))
        dist_from_center = 1 - dist_from_center / np.max(dist_from_center)
        if power is not None:
            dist_from_center = np.power(dist_from_center, power)
        dist_from_center = np.stack([dist_from_center] * 3, axis=2)
        return dist_from_center

    @staticmethod
    def resize_size(image, size=720, always_apply=True):
        """Resize so the shorter side equals `size`, preserving aspect ratio."""
        h, w, c = np.shape(image)
        if min(h, w) > size or always_apply:
            if h < w:
                h, w = int(size * h / w), size
            else:
                h, w = size, int(size * w / h)
            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
        return image

    def normalize(self, img):
        """uint8 [0, 255] -> float32 [-1, 1]."""
        return img.astype(np.float32) / 255 * 2 - 1

    def denormalize(self, img):
        """float [-1, 1] -> float [0, 1]."""
        return (img + 1) / 2

    def divide_crop(self, img, must_divided=32):
        """Center-crop so both sides are divisible by `must_divided` (model stride)."""
        h, w, _ = img.shape
        h = h // must_divided * must_divided
        w = w // must_divided * must_divided
        return center_crop(img, h, w)

    def merge_crops(self, faces_imgs, crops, full_image):
        """Blend each processed face back into `full_image` at its (x1, y1, x2, y2)
        box, feathered by the circular mask to hide crop seams."""
        for face, crop in zip(faces_imgs, crops):
            x1, y1, x2, y2 = crop
            W, H = x2 - x1, y2 - y1
            result_face = cv2.resize(face, (W, H), interpolation=cv2.INTER_LINEAR)
            face_mask = cv2.resize(self.mask, (W, H), interpolation=cv2.INTER_LINEAR)
            input_face = full_image[y1:y2, x1:x2]
            full_image[y1:y2, x1:x2] = (result_face * face_mask + input_face * (1 - face_mask)).astype(np.uint8)
        return full_image

    def __call__(self, img):
        return self.process_image(img)

    def process_image(self, img):
        # Prepare the frame for the background pass.
        img = self.resize_size(img, size=self.background_resize)
        img = self.divide_crop(img)

        face_crops, coords = self.face_detector(img)

        if len(face_crops) > 0:
            # Run the model on the batch of face crops (NHWC -> NCHW and back).
            start_time = time.time()
            faces = self.normalize(face_crops)
            faces = faces.transpose(0, 3, 1, 2)
            out_faces = self.model(faces)
            out_faces = self.denormalize(out_faces)
            out_faces = out_faces.transpose(0, 2, 3, 1)
            out_faces = np.clip(out_faces * 255, 0, 255).astype(np.uint8)
            end_time = time.time()
            logging.info(f'Face FPS {1 / (end_time - start_time)}')
        else:
            # No faces detected: shrink the frame further before the background pass.
            out_faces = []
            img = self.resize_size(img, size=self.no_detected_resize)
            img = self.divide_crop(img)

        # Run the model on the whole frame.
        start_time = time.time()
        full_image = self.normalize(img)
        full_image = np.expand_dims(full_image, 0).transpose(0, 3, 1, 2)
        full_image = self.model(full_image)
        full_image = self.denormalize(full_image)
        full_image = full_image.transpose(0, 2, 3, 1)
        full_image = np.clip(full_image * 255, 0, 255).astype(np.uint8)
        end_time = time.time()
        logging.info(f'Background FPS {1 / (end_time - start_time)}')

        # Paste the stylized faces back onto the stylized background.
        result_image = self.merge_crops(out_faces, coords, full_image[0])
        return result_image
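
# --- Usage sketch (illustrative only) ---
# A minimal sketch of how this pipeline might be wired up. It assumes `model`
# is any callable mapping a float32 NCHW array in [-1, 1] to an array of the
# same shape, and that `FaceDetector` exposes `target_size` and returns
# (face_crops, coords) when called on a BGR frame, as used above. The identity
# `DummyModel`, the default `FaceDetector()` construction, and the file paths
# below are hypothetical and only exercise the control flow.
#
#     import cv2
#     from .face_detector import FaceDetector
#
#     class DummyModel:
#         def __call__(self, batch):
#             return batch  # pass-through "model"
#
#     pipeline = VSNetModelPipeline(DummyModel(), FaceDetector())
#     frame = cv2.imread('input.jpg')
#     stylized = pipeline(frame)
#     cv2.imwrite('output.jpg', stylized)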