# JoJo_Style_Transfer / inference / model_pipeline.py
# Author: Podtekatel
# Last commit: ef164a1 — "Fix error for no objects" (file size: 4.04 kB)
import logging
import time
import cv2
import numpy as np
from .center_crop import center_crop
from .face_detector import FaceDetector
class VSNetModelPipeline:
    """Stylize an image with ``model``, re-running detected faces separately.

    Pipeline (see ``process_image``):

    1. Resize the input so its longer side is ``background_resize`` and
       center-crop it so both sides are multiples of 32.
    2. Run ``face_detector`` on that image. If it finds faces, the face crops
       are stylized as one batch. Otherwise the image is shrunk further to
       ``no_detected_resize`` before the background pass.
    3. Stylize the whole (resized) image.
    4. Paste each stylized face back over the stylized background. Each face
       is blended with a radial mask that is 1 at the crop center and falls to
       0 at the crop's inscribed circle. The mask is raised to ``power=1/4``,
       which keeps it close to 1 over most of the circle.

    ``model`` is called with a float32 batch laid out as (N, C, H, W) with
    values in [-1, 1]. Its output is assumed to use the same layout and range
    (NOTE(review): confirm against the model wrapper).
    """

    def __init__(self, model, face_detector: FaceDetector, background_resize=720, no_detected_resize=256):
        """
        Args:
            model: Callable applied to an (N, C, H, W) float32 batch in
                [-1, 1]. Its output is assumed to have the same layout/range.
            face_detector: Callable returning ``(face_crops, coords)`` for an
                image. Its ``target_size`` sets the blending-mask resolution.
            background_resize: Longer-side size the input image is resized to.
            no_detected_resize: Longer-side size used for the background pass
                when no face is detected.
        """
        self.background_resize = background_resize
        self.no_detected_resize = no_detected_resize
        self.model = model
        self.face_detector = face_detector
        # Square blending mask. The mask is resized to each face box anyway;
        # presumably target_size is the side of the detector's face crops (confirm).
        self.mask = self.create_circular_mask(face_detector.target_size, face_detector.target_size, power=1 / 4)

    @staticmethod
    def create_circular_mask(h, w, center=None, power=None):
        """Return an (h, w, 3) radial weight mask used for blending faces.

        The weight is ``1 - d / d_max``. Here ``d`` is the distance to
        ``center``, clipped at ``max(h, w) / 2``, and ``d_max`` is the largest
        clipped distance. The weight is therefore 1 at the center and 0 at the
        farthest (clipped) distance. If ``power`` is given, the weights are
        raised to it. The single channel is repeated 3 times so the mask can
        multiply an HWC color image directly.

        Args:
            h: Mask height in pixels.
            w: Mask width in pixels.
            center: ``(x, y)`` center; defaults to the middle of the mask.
            power: Optional exponent applied to the weights.
        """
        if center is None:  # use the middle of the image
            center = (int(w / 2), int(h / 2))
        ys, xs = np.ogrid[:h, :w]
        dist = np.sqrt((xs - center[0]) ** 2 + (ys - center[1]) ** 2)
        dist = np.clip(dist, a_min=0, a_max=max(h / 2, w / 2))
        max_dist = np.max(dist)
        if max_dist > 0:
            weights = 1 - dist / max_dist
        else:
            # Degenerate mask (e.g. 1x1): every pixel is the center. Use full
            # weight rather than dividing 0 by 0, which would yield NaN.
            weights = np.ones_like(dist)
        if power is not None:
            weights = np.power(weights, power)
        return np.stack([weights] * 3, axis=2)

    @staticmethod
    def resize_size(image, size=720, always_apply=True):
        """Resize ``image`` so its longer side equals ``size``, keeping aspect ratio.

        With ``always_apply=False`` the resize happens only when the shorter
        side exceeds ``size``; otherwise ``image`` is returned unchanged.
        Uses ``cv2.INTER_AREA``. Accepts (H, W) and (H, W, C) arrays.
        """
        h, w = np.shape(image)[:2]
        if min(h, w) > size or always_apply:
            # max(1, ...) stops extremely elongated images from producing a
            # 0-pixel side, which cv2.resize cannot produce.
            if h < w:
                h, w = max(1, int(size * h / w)), size
            else:
                h, w = size, max(1, int(size * w / h))
            image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
        return image

    def normalize(self, img):
        """Map pixel values from [0, 255] to float32 in [-1, 1]."""
        return img.astype(np.float32) / 255 * 2 - 1

    def denormalize(self, img):
        """Map values from [-1, 1] back to [0, 1]."""
        return (img + 1) / 2

    def divide_crop(self, img, must_divided=32):
        """Center-crop ``img`` so its height and width are multiples of ``must_divided``.

        NOTE(review): a side shorter than ``must_divided`` is cropped to 0
        pixels. The default of 32 presumably matches the model's total
        downsampling factor (confirm).
        """
        height, width = img.shape[:2]
        height = height // must_divided * must_divided
        width = width // must_divided * must_divided
        return center_crop(img, height, width)

    def merge_crops(self, faces_imgs, crops, full_image):
        """Blend stylized faces into ``full_image`` in place and return it.

        Each face is resized to its box and mixed as
        ``face * mask + background * (1 - mask)`` using ``self.mask``.

        Args:
            faces_imgs: Stylized HWC face images, one per box. They need 3
                channels to match the mask.
            crops: ``(x1, y1, x2, y2)`` boxes in ``full_image`` pixel
                coordinates, paired with ``faces_imgs`` in order.
            full_image: HWC uint8 image; it is modified in place.

        NOTE(review): boxes are assumed to hold integers and to lie fully
        inside ``full_image``. Presumably the face detector guarantees this
        (confirm).
        """
        if len(faces_imgs) == 0:
            # Nothing to paste. This also tolerates an empty/None ``crops``.
            return full_image
        for face, (x1, y1, x2, y2) in zip(faces_imgs, crops):
            box_w, box_h = x2 - x1, y2 - y1
            if box_w <= 0 or box_h <= 0:
                # Zero-area box: no pixels to blend, and cv2.resize would fail.
                continue
            result_face = cv2.resize(face, (box_w, box_h), interpolation=cv2.INTER_LINEAR)
            face_mask = cv2.resize(self.mask, (box_w, box_h), interpolation=cv2.INTER_LINEAR)
            input_face = full_image[y1:y2, x1:x2]
            blended = result_face * face_mask + input_face * (1 - face_mask)
            full_image[y1:y2, x1:x2] = blended.astype(np.uint8)
        return full_image

    def __call__(self, img):
        """Alias for ``process_image``."""
        return self.process_image(img)

    def _run_model(self, batch):
        """Stylize an (N, H, W, C) batch of [0, 255] pixels; return uint8 (N, H, W, C)."""
        model_input = self.normalize(batch).transpose(0, 3, 1, 2)  # NHWC -> NCHW
        output = self.denormalize(self.model(model_input))
        output = output.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        # The uint8 cast truncates (does not round) the scaled values.
        return np.clip(output * 255, 0, 255).astype(np.uint8)

    @staticmethod
    def _log_fps(label, start_time):
        """Log the throughput of a stage that began at ``start_time`` (a ``time.perf_counter()`` value)."""
        elapsed = time.perf_counter() - start_time
        # A fast stage can finish within one clock tick. Never divide by a
        # zero duration just to produce a log line.
        fps = 1 / elapsed if elapsed > 0 else float('inf')
        logging.info('%s FPS %s', label, fps)

    def process_image(self, img):
        """Stylize ``img`` and return the result as a uint8 (H, W, C) array.

        The output size differs from the input. The image is resized to
        ``background_resize``, or to ``no_detected_resize`` when no face is
        found, and center-cropped so both sides are multiples of 32.
        """
        img = self.resize_size(img, size=self.background_resize)
        img = self.divide_crop(img)
        face_crops, coords = self.face_detector(img)
        if len(face_crops) > 0:
            start_time = time.perf_counter()
            # asarray lets the detector return either a stacked array or a list of crops.
            out_faces = self._run_model(np.asarray(face_crops))
            self._log_fps('Face', start_time)
        else:
            out_faces = []
            # No faces to refine: stylize a smaller background instead.
            img = self.resize_size(img, size=self.no_detected_resize)
            img = self.divide_crop(img)
        start_time = time.perf_counter()
        full_image = self._run_model(np.expand_dims(img, 0))
        self._log_fps('Background', start_time)
        return self.merge_crops(out_faces, coords, full_image[0])