import os
import tempfile

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from face_detection import FaceDetector
from mark_detection import MarkDetector
from pose_estimation import PoseEstimator
from utils import refine

st.title("Head Pose Estimation")
st.text("Just a heads up (pun intended)... the code for this Space is largely borrowed from https://github.com/yinguobing/head-pose-estimation, slightly altered to handle single images and to run on Hugging Face.")
# Choose between Image or Video file upload
file_type = st.selectbox("Choose the type of file you want to upload", ("Image", "Video"))

# Only accept extensions matching the selected type, so an mp4 can't reach
# the PIL image path (and vice versa)
uploaded_file = st.file_uploader(
    f"Upload a {file_type.lower()} file of your face",
    type=["jpg", "jpeg", "png"] if file_type == "Image" else ["mp4", "mov", "avi", "mkv"],
)

# Display placeholder for real-time video output
FRAME_WINDOW = st.image([])
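
# Streamlit re-runs this script from the top on every interaction, so the ONNX
# models below are rebuilt on each upload. A minimal caching sketch (assumes a
# Streamlit version with st.cache_resource, and that the detector objects are
# safe to reuse across reruns); the per-branch initializations below could be
# replaced by a call to this helper:
@st.cache_resource
def load_detectors():
    return (FaceDetector("assets/face_detector.onnx"),
            MarkDetector("assets/face_landmarks.onnx"))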

if uploaded_file is not None:
    # Video processing
    if file_type == "Video":
        # Write the upload to a named temp file so OpenCV can open it by path
        tfile = tempfile.NamedTemporaryFile(delete=False)
        tfile.write(uploaded_file.read())
        tfile.close()  # close before reopening so this also works on Windows
        cap = cv2.VideoCapture(tfile.name)
        st.write(f"Processing video: {uploaded_file.name}")

        # Getting frame sizes
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        
        # Initialize face detection, landmark detection, and pose estimation models
        face_detector = FaceDetector("assets/face_detector.onnx")
        mark_detector = MarkDetector("assets/face_landmarks.onnx")
        pose_estimator = PoseEstimator(frame_width, frame_height)

        # Process each frame
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Step 1: Detect faces in the frame
            faces, _ = face_detector.detect(frame, 0.7)

            # If a face is detected, proceed with pose estimation
            if len(faces) > 0:
                # Refine the detection box to a square and take the first face
                face = refine(faces, frame_width, frame_height, 0.15)[0]
                x1, y1, x2, y2 = face[:4].astype(int)
                patch = frame[y1:y2, x1:x2]

                # Run landmark detection, then map the normalized patch-local
                # landmarks back to full-image coordinates (the refined box is
                # square, so a single scale factor covers both axes)
                marks = mark_detector.detect([patch])[0].reshape([68, 2])
                marks *= (x2 - x1)
                marks[:, 0] += x1
                marks[:, 1] += y1

                # Pose estimation with the detected landmarks
                pose = pose_estimator.solve(marks)

                # Draw the pose on the frame
                pose_estimator.visualize(frame, pose, color=(0, 255, 0))

            # Convert frame to RGB for Streamlit display
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            FRAME_WINDOW.image(frame_rgb)

        cap.release()
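
        # Delete the temporary file now that OpenCV has finished reading it
        os.unlink(tfile.name)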

    # Image processing
    elif file_type == "Image":
        # Load the upload as 3-channel RGB, then convert to BGR so the image
        # path matches the OpenCV-style frames used in the video path (and so
        # the BGR2RGB conversion before display below is correct)
        image = np.array(Image.open(uploaded_file).convert("RGB"))
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        frame_height, frame_width, _ = image.shape

        # Initialize models for detection and pose estimation
        face_detector = FaceDetector("assets/face_detector.onnx")
        mark_detector = MarkDetector("assets/face_landmarks.onnx")
        pose_estimator = PoseEstimator(frame_width, frame_height)

        # Detect face and landmarks
        faces, _ = face_detector.detect(image, 0.7)
        if len(faces) > 0:
            face = refine(faces, frame_width, frame_height, 0.15)[0]
            x1, y1, x2, y2 = face[:4].astype(int)
            patch = image[y1:y2, x1:x2]

            # Detect landmarks and map them to global image coordinates
            marks = mark_detector.detect([patch])[0].reshape([68, 2])
            marks *= (x2 - x1)
            marks[:, 0] += x1
            marks[:, 1] += y1

            # Estimate pose and visualize on image
            pose = pose_estimator.solve(marks)
            pose_estimator.visualize(image, pose, color=(0, 255, 0))

            # Convert image to RGB and display in Streamlit
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            st.image(image_rgb, caption="Pose Estimated Image", use_column_width=True)
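        else:
            # No face cleared the 0.7 confidence threshold; say so rather
            # than rendering nothing
            st.warning("No face detected in the uploaded image.")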