import logging
import queue
from pathlib import Path
from typing import List, NamedTuple

import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer

from utils.turn import get_ice_servers

from mtcnn import MTCNN
from PIL import Image
from transformers import pipeline

# Initialize the Hugging Face pipeline for facial emotion detection
emotion_pipeline = pipeline("image-classification", model="trpakov/vit-face-expression")

# Shared between the WebRTC callback thread and the Streamlit render loop
img_container = {"webcam": None, "analyzed": None}

# Initialize MTCNN for face detection
mtcnn = MTCNN()

HERE = Path(__file__).parent
ROOT = HERE.parent

logger = logging.getLogger(__name__)


class Detection(NamedTuple):
    class_id: int
    label: str
    score: float
    box: np.ndarray


# NOTE: The callback will be called in another thread,
# so use a queue here for thread safety to pass the data
# from inside to outside the callback.
# TODO: A general-purpose shared state object may be more useful.
result_queue: "queue.Queue[List[Detection]]" = queue.Queue()


def analyze_sentiment(face):
    """Classify the dominant emotion of a single BGR face crop."""
    # Convert the face crop to RGB, then to a PIL image for the pipeline
    rgb_face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_face)
    # The pipeline returns a list of {'label': ..., 'score': ...} dicts;
    # keep the highest-scoring label as the dominant emotion
    results = emotion_pipeline(pil_image)
    dominant_emotion = max(results, key=lambda x: x['score'])['label']
    return dominant_emotion


TEXT_SIZE = 1
LINE_SIZE = 2


def detect_and_draw_faces(frame):
    """Detect faces, classify each one's emotion, and annotate the frame."""
    # MTCNN expects RGB input; the frame arrives in OpenCV's BGR order
    results = mtcnn.detect_faces(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    detections: List[Detection] = []
    for result in results:
        x, y, w, h = result["box"]
        # MTCNN occasionally returns slightly negative coordinates; clamp them
        x, y = max(x, 0), max(y, 0)
        face = frame[y:y + h, x:x + w]
        if face.size == 0:
            continue
        sentiment = analyze_sentiment(face)
        detections.append(
            Detection(
                class_id=0,  # single "face" class
                label=sentiment,
                score=result["confidence"],
                box=np.array([x, y, w, h]),
            )
        )

        # Red bounding box around the face
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), LINE_SIZE)

        # Position the label just above the box, on a filled background
        text_size = cv2.getTextSize(sentiment, cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, 2)[0]
        text_x = x
        text_y = max(y - 10, text_size[1])  # keep the label inside the frame
        background_tl = (text_x, text_y - text_size[1])
        background_br = (text_x + text_size[0], text_y + 5)
        # Black rectangle as background, white text on top
        cv2.rectangle(frame, background_tl, background_br, (0, 0, 0), cv2.FILLED)
        cv2.putText(frame, sentiment, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, (255, 255, 255), 2)

    result_queue.put(detections)
    return frame


def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    img = frame.to_ndarray(format="bgr24")
    img_container["webcam"] = img
    frame_with_boxes = detect_and_draw_faces(img.copy())
    img_container["analyzed"] = frame_with_boxes
    # Return the original frame; the annotated copy is shown separately in
    # the "Analysis" column. To stream the annotated frames instead:
    # return av.VideoFrame.from_ndarray(frame_with_boxes, format="bgr24")
    return frame
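
# NOTE: img_container is written by the callback thread above and read by the
# render loop below. Individual dict reads/writes are atomic in CPython, but
# the two entries are not updated as one unit; if torn reads between the
# webcam and analyzed frames ever become a problem, a lock-guarded update is
# a drop-in alternative (a sketch only, not wired in here):
#
#   import threading
#   img_lock = threading.Lock()
#
#   # in video_frame_callback:
#   with img_lock:
#       img_container["webcam"] = img
#       img_container["analyzed"] = frame_with_boxes
#
#   # in the render loop:
#   with img_lock:
#       img = img_container["webcam"]
#       frame_with_boxes = img_container["analyzed"]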
ice_servers = get_ice_servers()

# Streamlit UI
st.title("Computer Vision Test Lab")
st.subheader("Facial Sentiment Analysis")

# Columns for the input and output streams
col1, col2 = st.columns(2)

with col1:
    st.header("Input Stream")
    st.subheader("Webcam")
    webrtc_ctx = webrtc_streamer(
        key="facial-sentiment",
        mode=WebRtcMode.SENDRECV,
        rtc_configuration={"iceServers": ice_servers},
        video_frame_callback=video_frame_callback,
        media_stream_constraints={"video": True, "audio": False},
        async_processing=True,
    )

with col2:
    st.header("Analysis")
    st.subheader("Input Frame")
    input_placeholder = st.empty()
    st.subheader("Output Frame")
    output_placeholder = st.empty()

    if webrtc_ctx.state.playing:
        show_labels = st.checkbox("Show the detected labels", value=True)
        labels_placeholder = st.empty()
        # NOTE: The face detection in the video callback and this loop
        # displaying the resulting labels run in different threads
        # asynchronously, so the rendered video frames and the labels
        # displayed here are not strictly synchronized.
        while True:
            result = result_queue.get()
            if show_labels:
                labels_placeholder.table(result)

            img = img_container["webcam"]
            frame_with_boxes = img_container["analyzed"]
            if img is None or frame_with_boxes is None:
                continue
            input_placeholder.image(img, channels="BGR")
            output_placeholder.image(frame_with_boxes, channels="BGR")
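
# To try this locally (a sketch; assumes this file is saved as app.py and that
# the local utils package provides get_ice_servers as imported above):
#
#   pip install streamlit streamlit-webrtc av opencv-python mtcnn transformers pillow numpy
#   streamlit run app.py
#
# The transformers pipeline also needs a backend such as torch, and the mtcnn
# package is TensorFlow/Keras-based, so expect those to be pulled in as well.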