ManhHoDinh's picture
up code
eb06a89
import os
import gradio as gr
import cv2
import pandas as pd
import random
from datetime import datetime
import firebase_admin
from firebase_admin import credentials
from firebase_admin import firestore
from ultralytics import YOLO
from tracker import Tracker
from utils import ID2LABEL, MODEL_PATH, AUTHEN_ACCOUNT, compute_color_for_labels
cred = credentials.Certificate(AUTHEN_ACCOUNT)
firebase_admin.initialize_app(cred)
db = firestore.client()
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
for j in range(10)]
detection_threshold = 0.1
model = YOLO(MODEL_PATH)
def addToDatabase(ss_id, obj_ids):
try:
new_doc = db.collection("TrafficData").document()
print(new_doc.id)
data = {
"SS_ID": ss_id,
"TF_COUNT_CAR": len(obj_ids['car']),
"TF_COUNT_MOTOBIKE": len(obj_ids['bicycle']) + len(obj_ids['motocycle']),
"TF_COUNT_OTHERS": len(obj_ids['bus']) + len(obj_ids['truck']) + len(obj_ids['other']),
"TF_ID": new_doc.id,
"TF_TIME": datetime.utcnow()
}
try:
db.collection("TrafficData").document(new_doc.id).set(data)
print("Sucessfully saved to database")
except:
print("Can't upload a new data")
except:
print("Can't create a new data")
def traffic_counting(video):
obj_ids = {"person": [],
"bicycle": [],
"car": [],
"motocycle": [],
"bus": [],
"truck": [],
"other": []}
cap = cv2.VideoCapture(video)
ret, frame = cap.read()
tracker = Tracker()
while ret:
results = model.predict(frame)
for result in results:
detections = []
for r in result.boxes.data.tolist():
x1, y1, x2, y2, score, class_id = r
x1 = int(x1)
x2 = int(x2)
y1 = int(y1)
y2 = int(y2)
class_id = int(class_id)
if score > detection_threshold:
detections.append([x1, y1, x2, y2, class_id, score])
tracker.update(frame, detections)
for track in tracker.tracks:
bbox = track.bbox
x1, y1, x2, y2 = bbox
track_id = track.track_id
class_id = track.class_id
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (compute_color_for_labels(class_id)), 3)
label_name = ID2LABEL[class_id] if class_id in ID2LABEL.keys() else "other"
if track_id not in obj_ids[label_name]:
obj_ids[label_name].append(track_id)
cv2.putText(frame,f"{label_name}-{track_id}",
(int(x1) + 5, int(y1) - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1, cv2.LINE_AA )
# Count each type of traffic
output_data = {key: len(value) for key, value in obj_ids.items()}
df = pd.DataFrame(list(output_data.items()), columns=['Type', 'Number'])
yield frame, df
ret, frame = cap.read()
cap.release()
cv2.destroyAllWindows()
video_path = video.replace("\\", "/")
# addToDatabase(video_path.split("/")[-1][:-4], obj_ids)
# input_video = gr.Video(label="Input Video")
# output_video = gr.outputs.Video(label="Processing Video")
# output_data = gr.Dataframe(interactive=False, label="Traffic's Frequency")
# demo = gr.Interface(traffic_counting,
# inputs=input_video,
# outputs=[output_video, output_data],
# examples=[os.path.join('video', x) for x in os.listdir('video') if x != ".gitkeep"],
# allow_flagging='never'
# )
def traffic_detection(image):
results = model.predict(image)
detections = []
obj_ids = {"person": [],
"bicycle": [],
"car": [],
"motocycle": [],
"bus": [],
"truck": [],
"other": []}
for result in results:
for r in result.boxes.data.tolist():
x1, y1, x2, y2, score, class_id = r
x1 = int(x1)
x2 = int(x2)
y1 = int(y1)
y2 = int(y2)
class_id = int(class_id)
if score > detection_threshold:
detections.append([x1, y1, x2, y2, class_id, score])
cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (compute_color_for_labels(class_id)), 1)
label_name = ID2LABEL[class_id] if class_id in ID2LABEL.keys() else "other"
cv2.putText(image,f"{label_name}",
(int(x1) + 5, int(y1) - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.3,compute_color_for_labels(class_id), 1, cv2.LINE_AA )
# Count each type of traffic
output_data = {key: len(value) for key, value in obj_ids.items()}
df = pd.DataFrame(list(output_data.items()), columns=['Type', 'Number'])
yield image, df
# Input is a image
input_image = gr.Image(label="Input Image")
output_image = gr.Image(type="filepath", label="Processing Image")
output_data = gr.Dataframe(interactive=False, label="Traffic's Frequency")
demo = gr.Interface(traffic_detection,
inputs=input_image,
outputs=[output_image, output_data],
examples=[os.path.join('image', x) for x in os.listdir('image') if x != ".gitkeep"],
allow_flagging='never'
)
if __name__ == "__main__":
demo.queue()
demo.launch(share= False)