Spaces:
Build error
Build error
import logging | |
import time | |
import cv2 | |
import numpy as np | |
from .center_crop import center_crop | |
from .face_detector import FaceDetector | |
class VSNetModelPipeline:
    """Run ``model`` on an image's background and on its detected faces, then blend them.

    The background is processed as one image. Faces found by ``face_detector``
    are processed as a batch and pasted back over their boxes, using a radial
    mask so each face fades into the processed background toward its box edges.
    """

    def __init__(self, model, face_detector: "FaceDetector", background_resize=720, no_detected_resize=256):
        """
        Args:
            model: Callable taking a float32 NCHW batch normalized to [-1, 1]. Its output
                is treated the same way (denormalized with ``(x + 1) / 2``, then
                transposed back to NHWC), so it is assumed to be an NCHW array in
                [-1, 1] with the input's spatial size — TODO confirm.
            face_detector: Callable ``img -> (face_crops, coords)``, where each coord is
                ``(x1, y1, x2, y2)``. Its ``target_size`` attribute sets the blend-mask size.
            background_resize: Target size for the longer image side when faces are processed.
            no_detected_resize: Target size for the longer side when no face is detected.
        """
        self.background_resize = background_resize
        self.no_detected_resize = no_detected_resize
        self.model = model
        self.face_detector = face_detector
        # power < 1 pushes mask values toward 1, so most of the face box keeps the
        # model's face output and only the rim blends into the background.
        self.mask = self.create_circular_mask(face_detector.target_size, face_detector.target_size, power=1 / 4)

    @staticmethod
    def create_circular_mask(h, w, center=None, power=None):
        """Build an (h, w, 3) radial weight mask: 1.0 at ``center``, falling to 0.0 outward.

        Distances are clipped at ``max(h, w) / 2`` and divided by the largest clipped
        distance, so the weight is ``1 - normalized_distance``. When ``power`` is given,
        the weights are raised to that power.
        """
        if center is None:  # default to the middle of the image
            center = (int(w / 2), int(h / 2))
        ys, xs = np.ogrid[:h, :w]
        dist = np.sqrt((xs - center[0]) ** 2 + (ys - center[1]) ** 2)
        dist = np.clip(dist, a_min=0, a_max=max(h / 2, w / 2))
        peak = np.max(dist)
        if peak > 0:  # a 1x1 mask has zero peak distance; avoid 0/0 -> NaN
            dist = dist / peak
        weight = 1 - dist
        if power is not None:
            weight = np.power(weight, power)
        # Replicate across 3 channels so it broadcasts against HxWx3 images.
        return np.stack([weight] * 3, axis=2)

    @staticmethod
    def resize_size(image, size=720, always_apply=True):
        """Resize ``image`` so its longer side equals ``size``, keeping the aspect ratio.

        With ``always_apply=False``, the resize happens only when the shorter side
        exceeds ``size``. Works for HxW and HxWxC arrays.
        """
        h, w = np.shape(image)[:2]
        # NOTE(review): the gate tests the shorter side, but the scale targets the longer one.
        if min(h, w) > size or always_apply:
            if h < w:
                h, w = max(1, int(size * h / w)), size
            else:
                h, w = size, max(1, int(size * w / h))
            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
        return image

    def normalize(self, img):
        """Map uint8-range pixel values [0, 255] to float32 [-1, 1]."""
        return np.asarray(img, dtype=np.float32) / 255 * 2 - 1

    def denormalize(self, img):
        """Map values from [-1, 1] back to [0, 1]."""
        return (img + 1) / 2

    def divide_crop(self, img, must_divided=32):
        """Crop ``img`` so height and width are multiples of ``must_divided``.

        Delegates to the project's ``center_crop``, which presumably keeps the
        central region — TODO confirm.
        """
        h, w, _ = img.shape
        h = h // must_divided * must_divided
        w = w // must_divided * must_divided
        return center_crop(img, h, w)

    def merge_crops(self, faces_imgs, crops, full_image):
        """Blend each processed face into ``full_image`` over its ``(x1, y1, x2, y2)`` box.

        Each face is weighted by ``self.mask``, resized to the box, so the face output
        dominates at the box centre and the existing ``full_image`` pixels show through
        toward the edges. Modifies ``full_image`` in place and returns it.
        """
        for face, crop in zip(faces_imgs, crops):
            x1, y1, x2, y2 = crop
            W, H = x2 - x1, y2 - y1
            result_face = cv2.resize(face, (W, H), interpolation=cv2.INTER_LINEAR)
            face_mask = cv2.resize(self.mask, (W, H), interpolation=cv2.INTER_LINEAR)
            input_face = full_image[y1: y2, x1: x2]
            full_image[y1: y2, x1: x2] = (result_face * face_mask + input_face * (1 - face_mask)).astype(np.uint8)
        return full_image

    def __call__(self, img):
        return self.process_image(img)

    def _run_model(self, batch):
        """Run ``self.model`` on an NHWC batch of [0, 255] pixels; return an NHWC uint8 batch."""
        x = self.normalize(batch).transpose(0, 3, 1, 2)  # NHWC -> NCHW for the model
        y = self.model(x)
        y = self.denormalize(y).transpose(0, 2, 3, 1)  # NCHW -> NHWC
        return np.clip(y * 255, 0, 255).astype(np.uint8)

    @staticmethod
    def _log_fps(label, start):
        """Log throughput for the interval since ``start`` (a ``time.perf_counter()`` value)."""
        elapsed = time.perf_counter() - start
        # A fast model can finish within one clock tick; report inf instead of dividing by 0.
        fps = 1 / elapsed if elapsed > 0 else float('inf')
        logging.info('%s FPS %s', label, fps)

    def process_image(self, img):
        """Process ``img`` (HxWx3) and return the blended uint8 result image."""
        img = self.resize_size(img, size=self.background_resize)
        img = self.divide_crop(img)
        face_crops, coords = self.face_detector(img)
        if len(face_crops) > 0:
            start = time.perf_counter()
            out_faces = self._run_model(face_crops)
            self._log_fps('Face', start)
        else:
            out_faces = []
            # With no face outputs, merge_crops pastes nothing (zip stops at the empty list),
            # so the background alone is processed at the smaller no-detection size.
            img = self.resize_size(img, size=self.no_detected_resize)
            img = self.divide_crop(img)
        start = time.perf_counter()
        full_image = self._run_model(np.expand_dims(img, 0))
        self._log_fps('Background', start)
        return self.merge_crops(out_faces, coords, full_image[0])