|
import logging |
|
import time |
|
|
|
import cv2 |
|
import numpy as np |
|
|
|
from .center_crop import center_crop |
|
from .face_detector import FaceDetector |
|
|
|
|
|
class VSNetModelPipeline:
    """Run a model over a whole image and, separately, over detected face crops.

    The image is resized and cropped, faces are located with ``face_detector``,
    the model is applied to the face crops and to the full frame, and the
    processed faces are blended back into the processed frame.
    """

    def __init__(self, model, face_detector: "FaceDetector", background_resize=720, no_detected_resize=256,
                 use_cloning=True):
        """Store the collaborators and pre-compute the face blending mask.

        Args:
            model: Callable applied to a channel-second float batch; its result must
                support ``.transpose`` (assumed NumPy array - confirm against the model wrapper).
            face_detector: Callable returning ``(face_crops, coords)`` for an image; must
                expose a ``target_size`` attribute used as the blending mask side length.
            background_resize: Target size passed to ``resize_size`` for the input frame.
            no_detected_resize: Target size used for the frame when no faces are found.
            use_cloning: If true, faces are merged with ``cv2.seamlessClone``; otherwise
                they are alpha-blended with the radial mask.
        """
        self.background_resize = background_resize
        self.no_detected_resize = no_detected_resize
        self.model = model
        self.face_detector = face_detector
        self.mask = self.create_circular_mask(face_detector.target_size, face_detector.target_size)
        self.use_cloning = use_cloning

    @staticmethod
    def create_circular_mask(h, w, power=None, clipping_coef=0.85):
        """Build an ``(h, w, 3)`` radial blending mask with values in ``[0, 1]``.

        The mask is 1 inside a disc around the centre and falls off linearly to 0
        at the larger half-extent of the mask.

        Args:
            h: Mask height in pixels.
            w: Mask width in pixels.
            power: Optional exponent applied to the mask to reshape the fall-off.
            clipping_coef: Fraction of the smaller half-extent that stays fully opaque.

        Returns:
            Float array of shape ``(h, w, 3)`` (the same 2-D mask repeated per channel).
        """
        center_x, center_y = int(w / 2), int(h / 2)

        Y, X = np.ogrid[:h, :w]
        dist_from_center = np.sqrt((X - center_x) ** 2 + (Y - center_y) ** 2)
        logging.debug('Mask distance range: max=%s min=%s', dist_from_center.max(), dist_from_center.min())

        # Pair each dimension with the centre coordinate on the SAME axis; mixing them
        # yields a wrong (even negative) radius for non-square masks.
        half_h, half_w = h - center_y, w - center_x
        clipping_radius = min(half_h, half_w) * clipping_coef
        max_size = max(half_h, half_w)

        # Clamp distances into [clipping_radius, max_size] (lower bound first, then upper).
        dist_from_center = np.minimum(np.maximum(dist_from_center, clipping_radius), max_size)

        max_distance, min_distance = np.max(dist_from_center), np.min(dist_from_center)
        spread = max_distance - min_distance
        if spread > 0:
            mask = 1 - (dist_from_center - min_distance) / spread
        else:
            # Every pixel sits at the minimum distance (e.g. 1x1 mask or clipping_coef >= 1):
            # the mask is fully opaque instead of NaN from a 0/0 division.
            mask = np.ones_like(dist_from_center)

        if power is not None:
            mask = np.power(mask, power)

        return np.stack([mask] * 3, axis=2)

    @staticmethod
    def resize_size(image, size=720, always_apply=True):
        """Resize ``image`` so that its longer side equals ``size``, keeping the aspect ratio.

        Args:
            image: Array whose first two dimensions are height and width.
            size: Target length of the longer side.
            always_apply: If false, the image is resized only when its shorter side
                exceeds ``size``.

        Returns:
            The resized image, or the input unchanged when no resize is applied.
        """
        h, w = np.shape(image)[:2]  # only height/width are needed, so 2-D images work too
        if min(h, w) > size or always_apply:
            if h < w:
                h, w = int(size * h / w), size
            else:
                h, w = size, int(size * w / h)
            # Extreme aspect ratios can truncate the short side to 0, which cv2.resize rejects.
            h, w = max(h, 1), max(w, 1)
            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
        return image

    def normalize(self, img):
        """Map values from ``[0, 255]`` to float32 values in ``[-1, 1]``."""
        return np.asarray(img, dtype=np.float32) / 255 * 2 - 1

    def denormalize(self, img):
        """Map values from ``[-1, 1]`` back to ``[0, 1]``."""
        return (img + 1) / 2

    def divide_crop(self, img, must_divided=32):
        """Crop ``img`` so that height and width are multiples of ``must_divided``.

        The crop itself is delegated to ``center_crop`` (presumably centred - see that helper).
        """
        h, w = img.shape[:2]
        h = h // must_divided * must_divided
        w = w // must_divided * must_divided
        return center_crop(img, h, w)

    def merge_crops(self, faces_imgs, crops, full_image):
        """Blend processed face images back into ``full_image``.

        Args:
            faces_imgs: Processed face images, one per crop.
            crops: ``(x1, y1, x2, y2)`` boxes in ``full_image`` coordinates.
            full_image: Processed frame; modified in place when ``use_cloning`` is false.

        Returns:
            The frame with all faces merged in.
        """
        for face, crop in zip(faces_imgs, crops):
            x1, y1, x2, y2 = crop
            W, H = x2 - x1, y2 - y1
            if W <= 0 or H <= 0:
                # A zero-area box cannot be resized into; skip it rather than crash.
                continue

            result_face = cv2.resize(face, (W, H), interpolation=cv2.INTER_LINEAR)
            face_mask = cv2.resize(self.mask, (W, H), interpolation=cv2.INTER_LINEAR)

            if self.use_cloning:
                # seamlessClone needs a binary uint8 mask and the integer centre of the target area.
                center = int(round((x2 + x1) / 2)), int(round((y2 + y1) / 2))
                binary_mask = (face_mask > 0.0).astype(np.uint8) * 255
                full_image = cv2.seamlessClone(result_face, full_image, binary_mask, center, cv2.NORMAL_CLONE)
            else:
                input_face = full_image[y1: y2, x1: x2]
                blended = result_face * face_mask + input_face * (1 - face_mask)
                full_image[y1: y2, x1: x2] = blended.astype(np.uint8)
        return full_image

    def __call__(self, img):
        """Shortcut for :meth:`process_image`."""
        return self.process_image(img)

    def _run_model(self, batch, label):
        """Apply the model to a channel-last batch, return a channel-last uint8 batch, log FPS under ``label``."""
        start_time = time.perf_counter()
        # The model is fed a channel-second layout; its output is converted back to channel-last.
        model_input = self.normalize(batch).transpose(0, 3, 1, 2)
        output = self.denormalize(self.model(model_input)).transpose(0, 2, 3, 1)
        output = np.clip(output * 255, 0, 255).astype(np.uint8)
        elapsed = time.perf_counter() - start_time

        # A zero elapsed time must not turn a log line into a ZeroDivisionError.
        fps = 1 / elapsed if elapsed > 0 else float('inf')
        logging.info('%s FPS %s', label, fps)
        return output

    def process_image(self, img):
        """Process a single image and return the merged result.

        The frame is resized to ``background_resize`` and cropped to a multiple of 32.
        Detected faces are run through the model separately; when none are found the
        frame is downsized further to ``no_detected_resize`` before the full-frame pass.

        Args:
            img: Input image (assumed HxWx3 uint8 - confirm against callers).

        Returns:
            The processed frame with processed faces merged in.
        """
        img = self.resize_size(img, size=self.background_resize)
        img = self.divide_crop(img)

        face_crops, coords = self.face_detector(img)

        # len() rather than truthiness: face_crops may be an array, whose truth value is ambiguous.
        if len(face_crops) > 0:
            out_faces = self._run_model(face_crops, 'Face')
        else:
            out_faces = []
            img = self.resize_size(img, size=self.no_detected_resize)
            img = self.divide_crop(img)

        full_image = self._run_model(np.expand_dims(img, 0), 'Background')

        return self.merge_crops(out_faces, coords, full_image[0])
|
|