# Spaces: Runtime error
# (The line above appears to be page-banner text captured with this file; it is not part of the module.)
import cv2 | |
import numpy as np | |
import tensorflow as tf | |
from .config import config as cfg | |
# The rest of this module uses the TF1 graph/session API. On TensorFlow 2+
# rebind ``tf`` to the v1 compatibility namespace.
# The major version is parsed to an int because comparing the raw version
# strings is lexicographic (e.g. '10.0' >= '2.0' is False).
if int(tf.__version__.split('.')[0]) >= 2:
    tf = tf.compat.v1
class FaceLandmark:
    """Facial-keypoint predictor backed by a TensorFlow (v1-style) graph.

    The constructor loads ``<dir>/keypoints.pb`` and resolves the tensors
    ``tower_0/images:0``, ``tower_0/prediction:0`` and ``training_flag:0``.
    Calling the instance runs the network once per face box.
    """

    def __init__(self, dir):
        """Load the frozen model and build the output slices.

        Args:
            dir: Directory that contains ``keypoints.pb``. The parameter name
                shadows the builtin but is kept for keyword callers.
        """
        self.model_path = dir + '/keypoints.pb'
        # A box is skipped only when BOTH its width and height are <= this.
        self.min_face = 60
        # Two values (x, y) per keypoint.
        self.keypoint_num = cfg.KEYPOINTS.p_num * 2

        # init_model returns the graph the model was actually imported into,
        # so tensor lookups below go through that graph explicitly.
        self._graph, self._sess = self.init_model(self.model_path)

        with self._graph.as_default():
            self.img_input = self._graph.get_tensor_by_name(
                'tower_0/images:0')
            self.embeddings = self._graph.get_tensor_by_name(
                'tower_0/prediction:0')
            self.training = self._graph.get_tensor_by_name(
                'training_flag:0')

            # Slices of the prediction vector: the first keypoint_num values,
            # the three values at [-7:-4] scaled by 90, and the sigmoid of
            # the last four values.
            self.landmark = self.embeddings[:, :self.keypoint_num]
            self.headpose = self.embeddings[:, -7:-4] * 90.
            self.state = tf.nn.sigmoid(self.embeddings[:, -4:])

    def __call__(self, img, bboxes):
        """Run the predictor on every box in *bboxes*.

        Args:
            img: Image the boxes refer to (indexed as H x W x C below).
            bboxes: Iterable of boxes, each indexable as (x0, y0, x1, y1, ...).

        Returns:
            ``(landmarks, states)`` as NumPy arrays with one entry per box
            that was not skipped; both are empty when nothing was processed.
        """
        landmark_result = []
        state_result = []
        for i, bbox in enumerate(bboxes):
            landmark, state = self._one_shot_run(img, bbox, i)
            if landmark is not None:
                landmark_result.append(landmark)
                state_result.append(state)
        return np.array(landmark_result), np.array(state_result)

    def simple_run(self, cropped_img):
        """Run the network on one already-cropped, already-resized image.

        Args:
            cropped_img: Single image; a leading batch axis is added here.

        Returns:
            ``(landmark, states)`` as returned by the session for a batch
            of one.
        """
        batch = np.expand_dims(cropped_img, axis=0)
        landmark, states = self._sess.run(
            [self.landmark, self.state],
            feed_dict={
                self.img_input: batch,
                self.training: False
            })
        return landmark, states

    def _one_shot_run(self, image, bbox, i):
        """Predict keypoints for a single face box.

        Args:
            image: Full image (H x W x C).
            bbox: Box indexable as (x0, y0, x1, y1, ...). It is NOT modified.
            i: Index of the box in the caller's list (unused; kept for
                interface compatibility).

        Returns:
            ``(landmark, state)`` where ``landmark`` is a float32 array of
            shape (num_points, 2) in coordinates of *image*, or
            ``(None, None)`` when the box is too small or the crop is empty.
        """
        bbox_width = bbox[2] - bbox[0]
        bbox_height = bbox[3] - bbox[1]
        if bbox_width <= self.min_face and bbox_height <= self.min_face:
            return None, None

        # Pad the image on every side so the enlarged square crop below does
        # not fall outside of it.
        add = int(max(bbox_width, bbox_height))
        bimg = cv2.copyMakeBorder(
            image,
            add,
            add,
            add,
            add,
            borderType=cv2.BORDER_CONSTANT,
            value=cfg.DATA.pixel_means)

        # Work on a private float copy: the original code shifted and
        # overwrote the caller's array in place.
        shifted = np.array(bbox, dtype=np.float64)[:4] + add

        # Square crop centred on the box; the side is the box width enlarged
        # by the configured extend range.
        one_edge = (1 + 2 * cfg.KEYPOINTS.base_extend_range[0]) * bbox_width
        half_edge = one_edge // 2
        center_x = (shifted[0] + shifted[2]) // 2
        center_y = (shifted[1] + shifted[3]) // 2
        # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the
        # drop-in replacement.
        crop_box = np.array([
            center_x - half_edge,
            center_y - half_edge,
            center_x + half_edge,
            center_y + half_edge,
        ]).astype(int)

        crop_image = bimg[crop_box[1]:crop_box[3], crop_box[0]:crop_box[2], :]
        if crop_image.size == 0:
            # cv2.resize raises on an empty source; treat it like a skipped box.
            return None, None
        h, w = crop_image.shape[:2]

        crop_image = cv2.resize(
            crop_image,
            (cfg.KEYPOINTS.input_shape[1], cfg.KEYPOINTS.input_shape[0]))
        crop_image = crop_image.astype(np.float32)

        keypoints, state = self.simple_run(crop_image)
        res = keypoints[0][:self.keypoint_num].reshape((-1, 2))

        # Map the network output back to the original image: scale by the
        # crop size, then undo the crop offset and the padding. The original
        # code divided x by input_shape[1] and multiplied it by
        # input_shape[0] (and the reverse for y), which only cancels for
        # square network inputs; scaling by the crop size directly is the
        # same result for square inputs and stays consistent otherwise.
        scale = np.array([w, h], dtype=np.float64)
        offset = np.array(
            [crop_box[0] - add, crop_box[1] - add], dtype=np.float64)
        coords = res.astype(np.float64) * scale + offset
        # Truncate toward zero (as int() did) and return float32.
        landmark = np.trunc(coords).astype(np.float32)
        return landmark, state

    def init_model(self, *args):
        """Load a model and return ``(graph, session)``.

        Args:
            *args: Either ``(pb_path,)`` for a frozen ``.pb`` graph, or
                ``(meta_path, restore_model_path)`` for a checkpoint.

        Returns:
            Tuple ``(graph, sess)``; ``graph`` is the graph the model was
            loaded into and ``sess`` is a session bound to that graph.

        Raises:
            ValueError: If the number of positional arguments is not 1 or 2.
        """
        if len(args) == 1:
            return self._load_frozen_graph(args[0])
        if len(args) == 2:
            return self._load_checkpoint(args[0], args[1])
        raise ValueError(
            'init_model expects (pb_path) or (meta_path, restore_model_path), '
            'got %d arguments' % len(args))

    @staticmethod
    def _load_checkpoint(meta_path, restore_model_path):
        """Import a meta graph into a fresh graph and restore its variables."""
        config_proto = tf.ConfigProto()
        config_proto.gpu_options.allow_growth = True
        graph = tf.Graph()
        # as_default() only takes effect as a context manager; without the
        # ``with`` the model would land in the ambient default graph.
        with graph.as_default():
            saver = tf.train.import_meta_graph(meta_path)
            sess = tf.Session(graph=graph, config=config_proto)
            saver.restore(sess, restore_model_path)
        print('Model restred!')
        return graph, sess

    @staticmethod
    def _load_frozen_graph(model_path):
        """Import a frozen ``.pb`` GraphDef into a fresh graph."""
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = 0.2
        compute_graph = tf.Graph()
        with tf.gfile.GFile(model_path, 'rb') as fid:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(fid.read())
        # Import into compute_graph explicitly so the returned graph is the
        # one that holds the model.
        with compute_graph.as_default():
            tf.import_graph_def(graph_def, name='')
        sess = tf.Session(graph=compute_graph, config=config)
        return compute_graph, sess