import os

# Install the face detector and pull in the 6DRepNet sources at runtime,
# a common pattern in Hugging Face Spaces.
os.system("pip install git+https://github.com/elliottzheng/face-detection.git@master")
os.system("git clone https://github.com/thohemp/6DRepNet")

import sys

# Make the cloned 6DRepNet modules (model.py, utils.py) importable.
sys.path.append("6DRepNet")

import numpy as np
import gradio as gr
import torch
from huggingface_hub import hf_hub_download

from face_detection import RetinaFace
from model import SixDRepNet
import utils
import cv2
from PIL import Image

# Download the pretrained weights for 6DRepNet (300W-LP / AFLW2000 checkpoint).
snapshot_path = hf_hub_download(repo_id="osanseviero/6DRepNet_300W_LP_AFLW2000", filename="model.pth")

# Build the network in deploy mode (re-parameterized RepVGG-B1g2 backbone).
model = SixDRepNet(backbone_name='RepVGG-B1g2',
                   backbone_file='',
                   deploy=True,
                   pretrained=False)

# RetinaFace face detector running on GPU 0.
detector = RetinaFace(0)

saved_state_dict = torch.load(snapshot_path, map_location='cpu')

# Some checkpoints wrap the weights in a 'model_state_dict' key.
if 'model_state_dict' in saved_state_dict:
    model.load_state_dict(saved_state_dict['model_state_dict'])
else:
    model.load_state_dict(saved_state_dict)
model.cuda(0)
model.eval()
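# For context: 6DRepNet regresses a 6D rotation representation and maps it to a
# full rotation matrix via Gram-Schmidt orthogonalization inside the forward
# pass. An illustrative sketch of that mapping (after Zhou et al., "On the
# Continuity of Rotation Representations"; not the repo's exact code):
#
#   import torch.nn.functional as F
#
#   def six_d_to_rotation(x):                  # x: (B, 6) raw network output
#       a1, a2 = x[:, :3], x[:, 3:]
#       b1 = F.normalize(a1, dim=1)            # first column
#       b2 = F.normalize(a2 - (b1 * a2).sum(1, keepdim=True) * b1, dim=1)
#       b3 = torch.cross(b1, b2, dim=1)        # third column via cross product
#       return torch.stack((b1, b2, b3), dim=-1)   # (B, 3, 3) rotation matrices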

def predict(frame):
    faces = detector(frame)
    for box, landmarks, score in faces:
        # Skip low-confidence detections.
        if score < .95:
            continue
        x_min = int(box[0])
        y_min = int(box[1])
        x_max = int(box[2])
        y_max = int(box[3])
        bbox_width = abs(x_max - x_min)
        bbox_height = abs(y_max - y_min)

        # Expand the box by 20% so the whole head fits in the crop.
        x_min = max(0, x_min - int(0.2 * bbox_height))
        y_min = max(0, y_min - int(0.2 * bbox_width))
        x_max = x_max + int(0.2 * bbox_height)
        y_max = y_max + int(0.2 * bbox_width)

        # Crop, resize, scale to [0, 1], and convert HWC -> CHW.
        img = frame[y_min:y_max, x_min:x_max]
        img = cv2.resize(img, (244, 244)) / 255.0
        img = img.transpose(2, 0, 1)
        img = torch.from_numpy(img).float().unsqueeze(0).cuda(0)

        # The model outputs a rotation matrix; convert to Euler angles in degrees.
        R_pred = model(img)
        euler = utils.compute_euler_angles_from_rotation_matrices(R_pred) * 180 / np.pi
        p_pred_deg = euler[:, 0].cpu()
        y_pred_deg = euler[:, 1].cpu()
        r_pred_deg = euler[:, 2].cpu()

        # Draw the pose cube on the first confident face and return the frame.
        return utils.plot_pose_cube(frame, y_pred_deg, p_pred_deg, r_pred_deg,
                                    x_min + int(.5 * (x_max - x_min)),
                                    y_min + int(.5 * (y_max - y_min)),
                                    size=bbox_width)

    # No confident detection: return the frame unchanged so Gradio still gets an image.
    return frame
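# A minimal sketch of calling predict() outside the web UI (hypothetical file
# names; requires a CUDA device, since the model and tensors live on GPU 0):
#
#   frame = cv2.imread("face.jpg")         # BGR image from disk
#   annotated = predict(frame)             # frame with the pose cube drawn
#   cv2.imwrite("face_pose.jpg", annotated)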
 
title = "6D Rotation Representation for Unconstrained Head Pose Estimation"
description = "Gradio demo for 6DRepNet. To use it, simply click the camera picture. Read more at the links below."
article = "<div style='text-align: center;'><a href='https://github.com/thohemp/6DRepNet' target='_blank'>Github Repo</a> | <a href='https://arxiv.org/abs/2202.12555' target='_blank'>Paper</a></div>"

image_flip_css = """
.input-image .image-preview  img{
    -webkit-transform: scaleX(-1);
    transform: scaleX(-1) !important;
}

.output-image img {
    -webkit-transform: scaleX(-1);
    transform: scaleX(-1) !important;
}
"""
  
# live=True re-runs predict whenever the webcam input updates.
iface = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(label="Input Image", source="webcam"),
    outputs='image',
    live=True,
    title=title,
    description=description,
    article=article,
    css=image_flip_css,
)

iface.launch()