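"""Streamlit demo: real-time facial sentiment analysis over WebRTC.

Faces are detected in each webcam frame with MTCNN, and each face crop is
classified with a ViT facial-expression model via a Hugging Face pipeline.
"""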
import logging
import queue
from pathlib import Path
from typing import List, NamedTuple

import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer

from utils.turn import get_ice_servers

from mtcnn import MTCNN
from PIL import Image
from transformers import pipeline

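# Hugging Face image-classification pipeline fine-tuned for facial expressions.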
emotion_pipeline = pipeline("image-classification", model="trpakov/vit-face-expression")

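# Latest raw and annotated frames, written by the WebRTC callback thread and
# read by the Streamlit script thread.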
img_container = {"webcam": None, "analyzed": None}

# MTCNN face detector.
mtcnn = MTCNN()

HERE = Path(__file__).parent
ROOT = HERE.parent

logger = logging.getLogger(__name__)


class Detection(NamedTuple):
    class_id: int
    label: str
    score: float
    box: np.ndarray


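# Thread-safe queue that hands detection results from the video callback
# thread to the UI loop below.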
result_queue: "queue.Queue[List[Detection]]" = queue.Queue()


def analyze_sentiment(face):
    """Return the dominant emotion label for a BGR face crop."""
    # OpenCV delivers BGR channel order; the Hugging Face pipeline expects RGB.
    rgb_face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb_face)
    results = emotion_pipeline(pil_image)
    # The pipeline returns a list of {"label", "score"} dicts; keep the top one.
    dominant_emotion = max(results, key=lambda x: x["score"])["label"]
    return dominant_emotion


# Drawing parameters for the face annotations.
TEXT_SIZE = 1
LINE_SIZE = 2


def detect_and_draw_faces(frame):
    """Detect faces in a BGR frame, classify each one, and draw annotations."""
    # MTCNN expects RGB input; the WebRTC frame arrives in BGR order.
    results = mtcnn.detect_faces(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

    for result in results:
        x, y, w, h = result["box"]
        # MTCNN can return slightly negative box coordinates; clamp them so
        # the crop below stays inside the frame.
        x, y = max(0, x), max(0, y)
        face = frame[y:y + h, x:x + w]
        if face.size == 0:
            continue
        sentiment = analyze_sentiment(face)
        cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), LINE_SIZE)

        # Draw a filled background box so the label stays readable.
        text_size = cv2.getTextSize(sentiment, cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, 2)[0]
        text_x = x
        text_y = y - 10
        background_tl = (text_x, text_y - text_size[1])
        background_br = (text_x + text_size[0], text_y + 5)

        cv2.rectangle(frame, background_tl, background_br, (0, 0, 0), cv2.FILLED)
        cv2.putText(frame, sentiment, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, (255, 255, 255), 2)

    result_queue.put(results)
    return frame


def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    img = frame.to_ndarray(format="bgr24")
    img_container["webcam"] = img
    frame_with_boxes = detect_and_draw_faces(img.copy())
    img_container["analyzed"] = frame_with_boxes

    # Return the unmodified frame; the annotated copy is shown in the
    # "Output Frame" placeholder instead.
    return frame


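# STUN/TURN servers for establishing the WebRTC connection.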
ice_servers = get_ice_servers()

st.markdown(
    """
    <style>
    .main {
        background-color: #F7F7F7;
        padding: 2rem;
    }
    h1, h2, h3 {
        color: #333333;
        font-family: 'Arial', sans-serif;
    }
    h1 {
        font-weight: 700;
        font-size: 2.5rem;
    }
    h2 {
        font-weight: 600;
        font-size: 2rem;
    }
    h3 {
        font-weight: 500;
        font-size: 1.5rem;
    }
    .stButton button {
        background-color: #E60012;
        color: white;
        border-radius: 5px;
        font-size: 16px;
        padding: 0.5rem 1rem;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

st.title("Computer Vision Test Lab")
st.subheader("Facial Sentiment Analysis")

col1, col2 = st.columns(2)

with col1:
    st.header("Input Stream")
    st.subheader("Webcam")
    webrtc_ctx = webrtc_streamer(
        key="object-detection",
        mode=WebRtcMode.SENDRECV,
        # rtc_configuration takes an RTCConfiguration-style dict; get_ice_servers()
        # is assumed to return the list of ICE servers, as in the streamlit-webrtc
        # samples, so wrap it under the "iceServers" key.
        rtc_configuration={"iceServers": ice_servers},
        video_frame_callback=video_frame_callback,
        media_stream_constraints={"video": True, "audio": False},
        async_processing=True,
    )

with col2:
    st.header("Analysis")
    st.subheader("Input Frame")
    input_placeholder = st.empty()
    st.subheader("Output Frame")
    output_placeholder = st.empty()

if webrtc_ctx.state.playing:
    if st.checkbox("Show the detected labels", value=True):
        labels_placeholder = st.empty()

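        # Main UI loop: block until the callback publishes new detections,
        # then refresh the label table and the frame previews.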
        while True:
            result = result_queue.get()
            labels_placeholder.table(result)

            img = img_container["webcam"]
            frame_with_boxes = img_container["analyzed"]

            if img is None:
                continue

            input_placeholder.image(img, channels="BGR")
            output_placeholder.image(frame_with_boxes, channels="BGR")