import os
import gradio as gr
import cv2
import pandas as pd
import random
from datetime import datetime
import firebase_admin
from firebase_admin import credentials, firestore
from ultralytics import YOLO

from tracker import Tracker
from utils import ID2LABEL, MODEL_PATH, AUTHEN_ACCOUNT, compute_color_for_labels

# Firebase Admin setup: authenticate with the service-account credentials and open a Firestore client.
cred = credentials.Certificate(AUTHEN_ACCOUNT)
firebase_admin.initialize_app(cred)
db = firestore.client()

# Fallback color palette (per-class colors are taken from compute_color_for_labels).
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
          for j in range(10)]

detection_threshold = 0.1
model = YOLO(MODEL_PATH)
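# Persist the per-class counts of one processed video to the "TrafficData" Firestore collection.
# ss_id identifies the source video (derived from its filename); obj_ids maps class name -> list of tracked IDs.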
def addToDatabase(ss_id, obj_ids):
    try:
        new_doc = db.collection("TrafficData").document()
        print(new_doc.id)
        data = {
            "SS_ID": ss_id,
            "TF_COUNT_CAR": len(obj_ids['car']),
            "TF_COUNT_MOTOBIKE": len(obj_ids['bicycle']) + len(obj_ids['motocycle']),
            "TF_COUNT_OTHERS": len(obj_ids['bus']) + len(obj_ids['truck']) + len(obj_ids['other']),
            "TF_ID": new_doc.id,
            "TF_TIME": datetime.utcnow()
        }
        try:
            db.collection("TrafficData").document(new_doc.id).set(data)
            print("Successfully saved to database")
        except Exception as e:
            print(f"Can't upload the new record: {e}")
    except Exception as e:
        print(f"Can't create a new record: {e}")
def traffic_counting(video):
    obj_ids = {"person": [],
               "bicycle": [],
               "car": [],
               "motocycle": [],
               "bus": [],
               "truck": [],
               "other": []}
    cap = cv2.VideoCapture(video)
    ret, frame = cap.read()
    tracker = Tracker()
    while ret:
        results = model.predict(frame)
        for result in results:
            detections = []
            for r in result.boxes.data.tolist():
                x1, y1, x2, y2, score, class_id = r
                x1 = int(x1)
                x2 = int(x2)
                y1 = int(y1)
                y2 = int(y2)
                class_id = int(class_id)
                if score > detection_threshold:
                    detections.append([x1, y1, x2, y2, class_id, score])
            tracker.update(frame, detections)
            for track in tracker.tracks:
                bbox = track.bbox
                x1, y1, x2, y2 = bbox
                track_id = track.track_id
                class_id = track.class_id
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)),
                              compute_color_for_labels(class_id), 3)
                label_name = ID2LABEL[class_id] if class_id in ID2LABEL else "other"
                # Count each track ID only once per class
                if track_id not in obj_ids[label_name]:
                    obj_ids[label_name].append(track_id)
                cv2.putText(frame, f"{label_name}-{track_id}",
                            (int(x1) + 5, int(y1) - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA)
        # Count each type of traffic
        output_data = {key: len(value) for key, value in obj_ids.items()}
        df = pd.DataFrame(list(output_data.items()), columns=['Type', 'Number'])
        yield frame, df
        ret, frame = cap.read()
    cap.release()
    cv2.destroyAllWindows()
    video_path = video.replace("\\", "/")
    # addToDatabase(video_path.split("/")[-1][:-4], obj_ids)
# input_video = gr.Video(label="Input Video") | |
# output_video = gr.outputs.Video(label="Processing Video") | |
# output_data = gr.Dataframe(interactive=False, label="Traffic's Frequency") | |
# demo = gr.Interface(traffic_counting, | |
# inputs=input_video, | |
# outputs=[output_video, output_data], | |
# examples=[os.path.join('video', x) for x in os.listdir('video') if x != ".gitkeep"], | |
# allow_flagging='never' | |
# ) | |
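# Run a single YOLO pass on one image, draw the detections, and return the annotated
# image together with a per-class count DataFrame.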
def traffic_detection(image):
    results = model.predict(image)
    detections = []
    obj_ids = {"person": [],
               "bicycle": [],
               "car": [],
               "motocycle": [],
               "bus": [],
               "truck": [],
               "other": []}
    for result in results:
        for r in result.boxes.data.tolist():
            x1, y1, x2, y2, score, class_id = r
            x1 = int(x1)
            x2 = int(x2)
            y1 = int(y1)
            y2 = int(y2)
            class_id = int(class_id)
            if score > detection_threshold:
                detections.append([x1, y1, x2, y2, class_id, score])
                cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)),
                              compute_color_for_labels(class_id), 1)
                label_name = ID2LABEL[class_id] if class_id in ID2LABEL else "other"
                # Record the detection so the per-class counts below are not always zero
                obj_ids[label_name].append((x1, y1, x2, y2))
                cv2.putText(image, f"{label_name}",
                            (int(x1) + 5, int(y1) - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.3, compute_color_for_labels(class_id), 1, cv2.LINE_AA)
    # Count each type of traffic
    output_data = {key: len(value) for key, value in obj_ids.items()}
    df = pd.DataFrame(list(output_data.items()), columns=['Type', 'Number'])
    yield image, df
# Input is an image
input_image = gr.Image(label="Input Image")
output_image = gr.Image(type="filepath", label="Processing Image")
output_data = gr.Dataframe(interactive=False, label="Traffic's Frequency")
demo = gr.Interface(traffic_detection,
                    inputs=input_image,
                    outputs=[output_image, output_data],
                    examples=[os.path.join('image', x) for x in os.listdir('image') if x != ".gitkeep"],
                    allow_flagging='never'
                    )

if __name__ == "__main__":
    demo.queue()
    demo.launch(share=False)