|
import cv2 |
|
import streamlit as st |
|
import numpy as np |
|
import tempfile |
|
import os |
|
import asyncio |
|
from ultralytics import YOLO |
|
from streamlit_webrtc import VideoTransformerBase, webrtc_streamer |
|
|
|
|
|
class ObjectTrackingTransformer(VideoTransformerBase):
    """Run YOLOv8 tracking on each frame of a streamlit-webrtc live stream.

    Frames are decoded as BGR, resized, tracked, and the annotated BGR image
    is returned for streamlit-webrtc to send back to the browser.
    """

    def __init__(self, model_path='yolov8n.pt', frame_size=(640, 480)):
        """Load the detection/tracking model.

        Args:
            model_path: Ultralytics weights to load (default ``'yolov8n.pt'``).
            frame_size: ``(width, height)`` every frame is resized to before
                tracking; passed straight to ``cv2.resize``.
        """
        super().__init__()
        self.model = YOLO(model_path)
        self.frame_size = frame_size

    def transform(self, frame):
        """Track objects in one video frame and return the annotated image.

        Args:
            frame: The incoming video frame object supplied by streamlit-webrtc
                (must support ``to_ndarray(format="bgr24")``).

        Returns:
            The annotated frame as a BGR numpy array of size ``self.frame_size``.
        """
        # Decode straight to BGR. Ultralytics treats numpy input as BGR; the
        # previous ``to_image()`` path produced RGB, so the model saw swapped
        # channels.
        frame_bgr = frame.to_ndarray(format="bgr24")

        frame_resized = cv2.resize(frame_bgr, self.frame_size)

        # persist=True keeps tracker state (track IDs) across successive calls
        # on this model instance.
        results = self.model.track(frame_resized, persist=True)

        # plot() draws on the BGR input and returns a BGR array, which is the
        # channel order streamlit-webrtc expects back from transform().
        return results[0].plot()
|
|
|
|
|
def main():
    """Build the Streamlit page and run either live or uploaded-video tracking.

    "Live Stream" hands frames to ``ObjectTrackingTransformer`` through
    ``webrtc_streamer``. "Upload Video" shows an uploader plus Start/Stop
    buttons and calls ``track_uploaded_video`` when Start is clicked with a
    file present.
    """
    st.set_page_config(page_title="Object Tracking with Streamlit")
    st.title("Object Tracking")

    choice = st.radio("Choose an option:", ("Live Stream", "Upload Video"))

    if choice == "Live Stream":
        webrtc_streamer(key="live-stream", video_transformer_factory=ObjectTrackingTransformer)
        return

    if choice != "Upload Video":
        return

    video = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
    start_clicked = st.button("Start Tracking")
    # Reserved slot where annotated frames are drawn, between the two buttons.
    display = st.empty()
    stop_clicked = st.button("Stop")

    should_track = start_clicked and video is not None
    if should_track:
        track_uploaded_video(video, stop_clicked, display)

    if video:
        video.close()
|
|
|
|
|
def track_uploaded_video(video_file, stop_button, frame_placeholder, *,
                         model_path='yolov8n.pt', frame_stride=5,
                         frame_size=(640, 480)):
    """Run YOLOv8 tracking over an uploaded video and display annotated frames.

    The upload is copied to a temporary file so ``cv2.VideoCapture`` can open
    it by path. Every ``frame_stride``-th frame is resized, tracked and drawn
    into ``frame_placeholder``. The capture is released and the temporary
    file deleted even if an exception escapes the loop.

    Args:
        video_file: File-like upload; the bytes from ``read()`` are written to
            disk. If it has a ``name`` attribute, that name's extension is
            reused for the temporary file.
        stop_button: Truthy to skip processing. Checked on every loop iteration,
            but this function never changes it.
        frame_placeholder: Object whose ``image()`` receives each annotated
            BGR frame (a Streamlit ``st.empty()`` placeholder in this file).
        model_path: Ultralytics weights to load (default ``'yolov8n.pt'``).
        frame_stride: Process one frame out of every ``frame_stride``; must be
            >= 1 (default 5, as before).
        frame_size: ``(width, height)`` passed to ``cv2.resize``.

    Raises:
        ValueError: If ``frame_stride`` is less than 1.
    """
    if frame_stride < 1:
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride!r}")

    model = YOLO(model_path)

    # Keep the upload's extension so the video backend is not forced to rely
    # solely on content sniffing.
    suffix = os.path.splitext(getattr(video_file, "name", "") or "")[1]
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    try:
        # Closing the handle flushes the bytes before OpenCV reads the path.
        with temp_video:
            temp_video.write(video_file.read())

        cap = cv2.VideoCapture(temp_video.name)
        try:
            if not cap.isOpened():
                st.error("Could not open the uploaded video.")
                return

            frame_count = 0
            while cap.isOpened() and not stop_button:
                ret, frame = cap.read()
                if not ret:
                    st.write("The video capture has ended.")
                    break

                if frame_count % frame_stride == 0:
                    frame_resized = cv2.resize(frame, frame_size)
                    # persist=True carries track IDs across calls on this model.
                    results = model.track(frame_resized, persist=True)
                    frame_placeholder.image(results[0].plot(), channels="BGR")

                frame_count += 1
        finally:
            cap.release()
    finally:
        # Always delete the temp copy, even if the loop was interrupted by an
        # exception (presumably including Streamlit aborting the run on a
        # rerun such as a Stop click). Previously this leaked the file.
        os.remove(temp_video.name)
|
|
|
|
|
# Script entry point: build the UI when this file is executed directly.
# NOTE(review): `streamlit run` presumably executes the script with
# __name__ == "__main__", so main() also runs under Streamlit — confirm.
if __name__ == "__main__":
    main()