# demo-counting / app.py
import os
import glob
import uuid
import gradio as gr
from PIL import Image
import cv2
import numpy as np
import supervision as sv
from ultralyticsplus import YOLO, download_from_hub
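
# NOTE: this script appears to target an older supervision release:
# BoxAnnotator(text_thickness=..., text_scale=..., skip_label=...),
# ByteTrack(track_thresh=...), ColorPalette.default(), and
# PolygonZone(frame_resolution_wh=...) were renamed or removed in newer
# versions, so pin supervision accordingly.
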
hf_model_ids = ["chanelcolgate/rods-count-v1", "chanelcolgate/cab-v1"]
image_paths = [
[image_path, "chanelcolgate/rods-count-v1", 640, 0.6, 0.45]
for image_path in glob.glob("./images/*.jpg")
]
video_paths = [
[video_path, "chanelcolgate/cab-v1"]
for video_path in glob.glob("./videos/*.mp4")
]


def get_center_of_bbox(bbox):
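    """Return the integer (x, y) center of an (x1, y1, x2, y2) box."""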
x1, y1, x2, y2 = bbox
return int((x1 + x2) / 2), int((y1 + y2) / 2)


def get_bbox_width(bbox):
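    """Return the integer width of an (x1, y1, x2, y2) box."""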
return int(bbox[2] - bbox[0])


def draw_circle(pil_image, bbox, color, idx):
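    """Draw a circle (radius ~60% of the half box width) around the center of
    `bbox` and stamp `idx` next to it; returns a new annotated PIL image."""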
# Convert PIL image to a numpy array (OpenCV format)
cv_image = np.array(pil_image)
# Convert RGB to BGR (OpenCV format)
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
x_center, y_center = get_center_of_bbox(bbox)
width = get_bbox_width(bbox)
# Draw the circle on the image
cv2.circle(
cv_image,
center=(x_center, y_center),
radius=int(width * 0.5 * 0.6),
color=color,
thickness=1,
)
cv2.putText(
cv_image,
f"{id}",
(x_center - 6, y_center + 6),
cv2.FONT_HERSHEY_SIMPLEX,
0.5,
(255, 249, 208),
2,
)
# Convert BGR back to RGB (PIL format)
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
# Convert the numpy array back to a PIL Image
pil_image = Image.fromarray(cv_image)
return pil_image


def count_predictions(
image=None,
hf_model_id="chanelcolgate/rods-count-v1",
image_size=640,
conf_threshold=0.25,
iou_threshold=0.45,
):
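    """Run YOLO detection on a single image, mark each detection with a
    numbered circle, and return the annotated image plus the detection count."""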
model_path = download_from_hub(hf_model_id)
model = YOLO(model_path)
results = model(
image, imgsz=image_size, conf=conf_threshold, iou=iou_threshold
)
detections = sv.Detections.from_ultralytics(results[0])
    image = image.copy()
    for idx, detection in enumerate(detections):
        bbox = detection[0].tolist()
        image = draw_circle(image, bbox, (90, 178, 255), idx + 1)
return image, len(detections)


def count_across_line(
source_video_path=None,
hf_model_id="chanelcolgate/cab-v1",
):
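    """Track detections with ByteTrack and count objects crossing a fixed
    vertical line; returns the annotated video path and the outward count."""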
TARGET_VIDEO_PATH = os.path.join("./", f"{uuid.uuid4()}.mp4")
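    # The line endpoints below are hard-coded; they appear to be tuned to the
    # bundled example videos, so adjust them for other resolutions.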
LINE_START = sv.Point(976, 212)
LINE_END = sv.Point(976, 1276)
model_path = download_from_hub(hf_model_id)
model = YOLO(model_path)
byte_tracker = sv.ByteTrack(
track_thresh=0.25, track_buffer=30, match_thresh=0.8, frame_rate=30
)
line_zone = sv.LineZone(start=LINE_START, end=LINE_END)
box_annotator = sv.BoxAnnotator(thickness=4, text_thickness=4, text_scale=2)
trace_annotator = sv.TraceAnnotator(thickness=4, trace_length=50)
line_zone_annotator = sv.LineZoneAnnotator(
thickness=4, text_thickness=4, text_scale=2
)

    def callback(frame: np.ndarray, index: int) -> np.ndarray:
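        # Per-frame hook for sv.process_video: run detection, update the
        # ByteTrack tracker, draw traces and boxes, then update the line count.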
results = model.predict(frame)
cls_names = results[0].names
detection = sv.Detections.from_ultralytics(results[0])
detection_supervision = byte_tracker.update_with_detections(detection)
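        # Tracker-aware labels; these are only used if the commented-out
        # labels= argument on box_annotator.annotate below is re-enabled.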
labels_convert = [
f"#{tracker_id} {cls_names[class_id]} {confidence:0.2f}"
for _, _, confidence, class_id, tracker_id, _ in detection_supervision
]
annotated_frame = trace_annotator.annotate(
scene=frame.copy(), detections=detection_supervision
)
annotated_frame = box_annotator.annotate(
scene=annotated_frame,
detections=detection_supervision,
skip_label=True,
# labels=labels_convert,
)
# update line counter
line_zone.trigger(detection_supervision)
# return frame with box and line annotated result
return line_zone_annotator.annotate(
annotated_frame, line_counter=line_zone
)
# process the whole video
sv.process_video(
source_path=source_video_path,
target_path=TARGET_VIDEO_PATH,
callback=callback,
)
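    # line_zone also tracks in_count; only outward crossings are reported.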
return TARGET_VIDEO_PATH, line_zone.out_count


def count_in_zone(
source_video_path=None,
hf_model_id="chanelcolgate/cab-v1",
):
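    """Track detections and count how many fall inside two fixed polygon
    zones; returns the annotated video path and the final per-zone counts."""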
TARGET_VIDEO_PATH = os.path.join("./", f"{uuid.uuid4()}.mp4")
colors = sv.ColorPalette.default()
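    # Hard-coded zone polygons, seemingly tuned to the bundled example videos;
    # adjust them for other resolutions.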
polygons = [
np.array([[88, 292], [748, 284], [736, 1160], [96, 1148]]),
np.array([[844, 240], [844, 1132], [1580, 1124], [1584, 264]]),
]
model_path = download_from_hub(hf_model_id)
model = YOLO(model_path)
byte_tracker = sv.ByteTrack(
track_thresh=0.25, track_buffer=30, match_thresh=0.8, frame_rate=30
)
video_info = sv.VideoInfo.from_video_path(source_video_path)
zones = [
sv.PolygonZone(
polygon=polygon, frame_resolution_wh=video_info.resolution_wh
)
for polygon in polygons
]
zone_annotators = [
sv.PolygonZoneAnnotator(
zone=zone,
color=colors.by_idx(index),
thickness=4,
text_thickness=4,
text_scale=2,
)
for index, zone in enumerate(zones)
]
box_annotators = [
sv.BoxAnnotator(
thickness=4,
text_thickness=4,
text_scale=2,
color=colors.by_idx(index),
)
for index in range(len(polygons))
]

    def callback(frame: np.ndarray, index: int) -> np.ndarray:
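        # Per-frame hook for sv.process_video: run detection, update the
        # tracker, then trigger and annotate each polygon zone.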
results = model.predict(frame)
detection = sv.Detections.from_ultralytics(results[0])
detection_supervision = byte_tracker.update_with_detections(detection)
for zone, zone_annotator, box_annotator in zip(
zones, zone_annotators, box_annotators
):
zone.trigger(detections=detection_supervision)
frame = box_annotator.annotate(
scene=frame, detections=detection_supervision, skip_label=True
)
frame = zone_annotator.annotate(scene=frame)
return frame
sv.process_video(
source_path=source_video_path,
target_path=TARGET_VIDEO_PATH,
callback=callback,
)
return TARGET_VIDEO_PATH, [zone.current_count for zone in zones]


interface_count_predictions = gr.Interface(
fn=count_predictions,
inputs=[
gr.Image(type="pil"),
gr.Dropdown(hf_model_ids),
gr.Slider(
minimum=320, maximum=1280, value=640, step=32, label="Image Size"
),
gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.25,
step=0.05,
label="Confidence Threshold",
),
gr.Slider(
minimum=0.0,
maximum=1.0,
value=0.45,
step=0.05,
label="IOU Threshold",
),
],
outputs=[gr.Image(type="pil"), gr.Textbox(show_label=False)],
title="Count Predictions",
examples=image_paths,
    cache_examples=bool(image_paths),
)

interface_count_across_line = gr.Interface(
fn=count_across_line,
inputs=[
gr.Video(label="Input Video"),
gr.Dropdown(hf_model_ids),
],
outputs=[gr.Video(label="Output Video"), gr.Textbox(show_label=False)],
title="Count Across Line",
examples=video_paths,
    cache_examples=bool(video_paths),
)

interface_count_in_zone = gr.Interface(
fn=count_in_zone,
inputs=[gr.Video(label="Input Video"), gr.Dropdown(hf_model_ids)],
outputs=[gr.Video(label="Output Video"), gr.Textbox(show_label=False)],
title="Count in Zone",
examples=video_paths,
    cache_examples=bool(video_paths),
)

gr.TabbedInterface(
[
interface_count_predictions,
interface_count_across_line,
interface_count_in_zone,
],
tab_names=["Count Predictions", "Count Across Line", "Count in Zone"],
title="Demo Counting",
).queue().launch()