import os

# Install the RetinaFace face detector and fetch the 6DRepNet sources at runtime,
# since the Space does not vendor them.
os.system("pip install git+https://github.com/elliottzheng/face-detection.git@master")
os.system("git clone https://github.com/thohemp/6DRepNet")
import sys
sys.path.append("6DRepNet")
import numpy as np
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from face_detection import RetinaFace
from model import SixDRepNet
import utils
import cv2
from PIL import Image
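
# Download the pretrained weights from the Hugging Face Hub (per the repo id,
# a model trained on 300W-LP and evaluated on AFLW2000).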
snapshot_path = hf_hub_download(repo_id="osanseviero/6DRepNet_300W_LP_AFLW2000", filename="model.pth")

# Build the pose network on a RepVGG-B1g2 backbone and load the downloaded weights.
model = SixDRepNet(backbone_name='RepVGG-B1g2',
                   backbone_file='',
                   deploy=True,
                   pretrained=False)

# RetinaFace face detector on GPU 0.
detector = RetinaFace(0)

saved_state_dict = torch.load(snapshot_path, map_location='cpu')
if 'model_state_dict' in saved_state_dict:
    model.load_state_dict(saved_state_dict['model_state_dict'])
else:
    model.load_state_dict(saved_state_dict)

model.cuda(0)
model.eval()
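
# Inference pipeline: RetinaFace proposes face boxes; each (slightly enlarged) crop
# is passed through SixDRepNet, which regresses a rotation matrix; utils converts
# that matrix to pitch/yaw/roll Euler angles and draws a pose cube onto the frame.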
def predict(frame):
    faces = detector(frame)
    for box, landmarks, score in faces:
        # Skip low-confidence detections.
        if score < .95:
            continue

        x_min = int(box[0])
        y_min = int(box[1])
        x_max = int(box[2])
        y_max = int(box[3])
        bbox_width = abs(x_max - x_min)
        bbox_height = abs(y_max - y_min)

        # Enlarge the box by 20% on each side so the whole head fits in the crop.
        x_min = max(0, x_min - int(0.2 * bbox_height))
        y_min = max(0, y_min - int(0.2 * bbox_width))
        x_max = x_max + int(0.2 * bbox_height)
        y_max = y_max + int(0.2 * bbox_width)

        # Crop, resize, scale to [0, 1], and move channels first (HWC -> CHW).
        img = frame[y_min:y_max, x_min:x_max]
        img = cv2.resize(img, (244, 244)) / 255.0
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float().cuda(0)
        img = img.unsqueeze(0)

        # The network outputs a rotation matrix; convert it to Euler angles in degrees.
        R_pred = model(img)
        euler = utils.compute_euler_angles_from_rotation_matrices(R_pred) * 180 / np.pi
        p_pred_deg = euler[:, 0].cpu()
        y_pred_deg = euler[:, 1].cpu()
        r_pred_deg = euler[:, 2].cpu()

        # Draw the pose cube on the first confident face and return the annotated frame.
        return utils.plot_pose_cube(frame, y_pred_deg, p_pred_deg, r_pred_deg,
                                    x_min + int(.5 * (x_max - x_min)),
                                    y_min + int(.5 * (y_max - y_min)),
                                    size=bbox_width)

    # No confident face found: return the frame unchanged instead of None.
    return frame
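
# A minimal sketch of calling predict() outside Gradio, assuming a local RGB test
# image is available (the "face.jpg" path below is hypothetical):
#
#   frame = np.array(Image.open("face.jpg").convert("RGB"))
#   annotated = predict(frame)
#   Image.fromarray(annotated.astype(np.uint8)).save("face_pose.jpg")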
title = "6D Rotation Representation for Unconstrained Head Pose Estimation"
description = "Gradio demo for 6DRepNet. To use it, simply click the camera picture. Read more at the links below."
article = "<div style='text-align: center;'><a href='https://github.com/thohemp/6DRepNet' target='_blank'>Github Repo</a> | <a href='https://arxiv.org/abs/2202.12555' target='_blank'>Paper</a></div>"
image_flip_css = """
.input-image .image-preview img{
-webkit-transform: scaleX(-1);
transform: scaleX(-1) !important;
}
.output-image img {
-webkit-transform: scaleX(-1);
transform: scaleX(-1) !important;
}
"""
iface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Input Image", source="webcam"),
    outputs='image',
    live=True,
    title=title,
    description=description,
    article=article,
    css=image_flip_css
)

iface.launch()