osanseviero commited on
Commit
1d53eef
1 Parent(s): c5d2104

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.system("git clone https://github.com/thohemp/6DRepNet")
3
+ import sys
4
+ sys.path.append("frame-interpolation")
5
+
6
+ from model import SixDRepNet
7
+ import math
8
+ import re
9
+ from matplotlib import pyplot as plt
10
+ import sys
11
+ import os
12
+
13
+ import numpy as np
14
+ import cv2
15
+ import matplotlib.pyplot as plt
16
+ from numpy.lib.function_base import _quantile_unchecked
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.utils.data import DataLoader
21
+ from torchvision import transforms
22
+ import torchvision
23
+ import torch.nn.functional as F
24
+ import utils
25
+ import matplotlib
26
+ from PIL import Image
27
+ import time
28
+ from face_detection import RetinaFace
29
+ from huggingface_hub import hf_hub_download
30
+
31
+ snapshot_path = hf_hub_download(repo_id="osanseviero/6DRepNet_300W_LP_AFLW2000", filename="model.pth")
32
+
33
+ model = SixDRepNet(backbone_name='RepVGG-B1g2',
34
+ backbone_file='',
35
+ deploy=True,
36
+ pretrained=False)
37
+
38
+ detector = RetinaFace()
39
+ saved_state_dict = torch.load(os.path.join(
40
+ snapshot_path), map_location='cpu')
41
+
42
+ if 'model_state_dict' in saved_state_dict:
43
+ model.load_state_dict(saved_state_dict['model_state_dict'])
44
+ else:
45
+ model.load_state_dict(saved_state_dict)
46
+ model.eval()
47
+
48
+
49
+ def predict(img):
50
+ faces = detector(frame)
51
+ for box, landmarks, score in faces:
52
+ # Print the location of each face in this image
53
+ if score < .95:
54
+ continue
55
+ x_min = int(box[0])
56
+ y_min = int(box[1])
57
+ x_max = int(box[2])
58
+ y_max = int(box[3])
59
+ bbox_width = abs(x_max - x_min)
60
+ bbox_height = abs(y_max - y_min)
61
+
62
+ x_min = max(0,x_min-int(0.2*bbox_height))
63
+ y_min = max(0,y_min-int(0.2*bbox_width))
64
+ x_max = x_max+int(0.2*bbox_height)
65
+ y_max = y_max+int(0.2*bbox_width)
66
+
67
+ img = frame[y_min:y_max,x_min:x_max]
68
+ img = cv2.resize(img, (244, 244))/255.0
69
+ img = img.transpose(2, 0, 1)
70
+ img = torch.from_numpy(img).type(torch.FloatTensor)
71
+ img = torch.Tensor(img)
72
+ img=img.unsqueeze(0)
73
+
74
+ R_pred = model(img)
75
+ euler = utils.compute_euler_angles_from_rotation_matrices(
76
+ R_pred)*180/np.pi
77
+ p_pred_deg = euler[:, 0].cpu()
78
+ y_pred_deg = euler[:, 1].cpu()
79
+ r_pred_deg = euler[:, 2].cpu()
80
+ utils.plot_pose_cube(frame, y_pred_deg, p_pred_deg, r_pred_deg, x_min + int(.5*(x_max-x_min)), y_min + int(.5*(y_max-y_min)), size = bbox_width)
81
+
82
+ return img
83
+
84
+
85
+ iface = gr.Interface(
86
+ fn=predict,
87
+ inputs='img',
88
+ outputs='img',
89
+ )
90
+
91
+ iface.launch()