import os
import tempfile

import cv2
import numpy as np
import streamlit as st
from PIL import Image

from face_detection import FaceDetector
from mark_detection import MarkDetector
from pose_estimation import PoseEstimator
from utils import refine
st.title("Head Pose Estimation")
# NOTE(review): this attribution sentence appears to be missing its source
# link ("borrowed from Slightly altered ...") -- restore the URL.
st.text("Just a heads up (pun intended)... The code used for this space is largely borrowed from Slightly altered to fit image needs and make it work on huggingface.")
# Choose between Image or Video file upload
file_type = st.selectbox("Choose the type of file you want to upload", ("Image", "Video"))
# Only offer extensions that the selected branch can actually decode:
# images are opened with PIL, videos with cv2.VideoCapture. Accepting every
# extension in both modes let e.g. an .mp4 reach PIL in "Image" mode and fail.
_ALLOWED_EXTENSIONS = {
    "Image": ["jpg", "jpeg", "png"],
    "Video": ["mp4", "mov", "avi", "mkv"],
}
uploaded_file = st.file_uploader(
    "Upload an image or video file of your face",
    type=_ALLOWED_EXTENSIONS[file_type],
)
# Display placeholder for real-time video output; the video loop overwrites
# it with each annotated frame.
FRAME_WINDOW = st.image([])
def _create_models(frame_width, frame_height):
    """Build the face detector, landmark detector and pose estimator.

    Args:
        frame_width: Width in pixels of the frames the pose estimator solves for.
        frame_height: Height in pixels of those frames.

    Returns:
        A ``(face_detector, mark_detector, pose_estimator)`` tuple.
    """
    face_detector = FaceDetector("assets/face_detector.onnx")
    mark_detector = MarkDetector("assets/face_landmarks.onnx")
    pose_estimator = PoseEstimator(frame_width, frame_height)
    return face_detector, mark_detector, pose_estimator


def _draw_head_pose(frame, face_detector, mark_detector, pose_estimator):
    """Estimate the head pose of the first detected face and draw it in place.

    Args:
        frame: BGR image array (H x W x 3); annotated in place.
        face_detector, mark_detector, pose_estimator: Models from
            ``_create_models``.

    Returns:
        True if a face was found and its pose drawn, False otherwise.
    """
    frame_height, frame_width = frame.shape[:2]
    # Step 1: detect faces (0.7 confidence threshold).
    faces, _ = face_detector.detect(frame, 0.7)
    if len(faces) == 0:
        return False
    # Step 2: refine the boxes (see utils.refine) and keep the first face.
    face = refine(faces, frame_width, frame_height, 0.15)[0]
    x1, y1, x2, y2 = face[:4].astype(int)
    patch = frame[y1:y2, x1:x2]
    if patch.size == 0:
        # Degenerate box: there is no pixel data to run landmarks on.
        return False
    # Step 3: landmarks come back patch-relative (presumably normalized);
    # scale by the patch width and shift into global image coordinates.
    # Scaling both axes by the width assumes refine() yields square boxes
    # -- TODO confirm.
    marks = mark_detector.detect([patch])[0].reshape([68, 2])
    marks *= (x2 - x1)
    marks[:, 0] += x1
    marks[:, 1] += y1
    # Step 4: solve the pose from the landmarks and draw it on the frame.
    pose = pose_estimator.solve(marks)
    pose_estimator.visualize(frame, pose, color=(0, 255, 0))
    return True


def _process_video(video_file, placeholder):
    """Run head pose estimation on every frame of an uploaded video.

    cv2.VideoCapture needs a filesystem path, so the upload is copied to a
    named temporary file. The file is closed before OpenCV opens it, so all
    bytes are flushed to disk, and it is deleted once processing ends.

    Args:
        video_file: The Streamlit upload holding the video bytes.
        placeholder: Streamlit element that is overwritten with each
            annotated frame.
    """
    suffix = os.path.splitext(video_file.name)[1]
    # getvalue() returns the whole upload regardless of the stream position,
    # unlike read(), which returns nothing if the buffer was already consumed.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
        tfile.write(video_file.getvalue())
    cap = cv2.VideoCapture(tfile.name)
    try:
        st.write(f"Video source: {tfile.name}")
        if not cap.isOpened():
            st.error("Could not open the uploaded video file.")
            return
        # Getting frame sizes
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        face_detector, mark_detector, pose_estimator = _create_models(
            frame_width, frame_height
        )
        # Process each frame until the stream is exhausted.
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            _draw_head_pose(frame, face_detector, mark_detector, pose_estimator)
            # OpenCV frames are BGR; Streamlit displays RGB.
            placeholder.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        cap.release()
        os.remove(tfile.name)


def _process_image(image_file):
    """Run head pose estimation on an uploaded image and display the result.

    Args:
        image_file: The Streamlit upload holding the image bytes.
    """
    # PIL may decode to RGB, RGBA, L or P depending on the file. Normalize to
    # 3-channel RGB, then flip to BGR: the video branch feeds the same models
    # BGR frames from OpenCV, and the display step below converts BGR -> RGB.
    rgb = np.array(Image.open(image_file).convert("RGB"))
    image = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    frame_height, frame_width = image.shape[:2]
    face_detector, mark_detector, pose_estimator = _create_models(
        frame_width, frame_height
    )
    _draw_head_pose(image, face_detector, mark_detector, pose_estimator)
    # Convert image to RGB and display in Streamlit
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    st.image(image_rgb, caption="Pose Estimated Image", use_column_width=True)


# Dispatch on the user's selected upload type.
if uploaded_file is not None:
    if file_type == "Video":
        _process_video(uploaded_file, FRAME_WINDOW)
    elif file_type == "Image":
        _process_image(uploaded_file)