import os
import glob
import uuid

import gradio as gr
from PIL import Image
import cv2
import numpy as np
import supervision as sv
from ultralyticsplus import YOLO, download_from_hub
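
# NOTE: the code below relies on an older supervision API surface
# (sv.ByteTrack(track_thresh=...), sv.BoxAnnotator(..., skip_label=...),
# sv.PolygonZone(frame_resolution_wh=...), sv.ColorPalette.default()).
# Newer supervision releases renamed or removed these calls; pinning an
# older release (assumption: supervision 0.17.x or similar) keeps them valid.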

hf_model_ids = ["chanelcolgate/rods-count-v1", "chanelcolgate/cab-v1"]

image_paths = [
    [image_path, "chanelcolgate/rods-count-v1", 640, 0.6, 0.45]
    for image_path in glob.glob("./images/*.jpg")
]
video_paths = [
    [video_path, "chanelcolgate/cab-v1"]
    for video_path in glob.glob("./videos/*.mp4")
]

def get_center_of_bbox(bbox):
    """Return the integer (x, y) center of an [x1, y1, x2, y2] box."""
    x1, y1, x2, y2 = bbox
    return int((x1 + x2) / 2), int((y1 + y2) / 2)


def get_bbox_width(bbox):
    """Return the integer width of an [x1, y1, x2, y2] box."""
    return int(bbox[2] - bbox[0])

def draw_circle(pil_image, bbox, color, id):
    """Draw a circle (sized from the box width) and its id label on a PIL image."""
    # Convert the PIL image to a numpy array (OpenCV format)
    cv_image = np.array(pil_image)
    # Convert RGB to BGR (OpenCV's channel order)
    cv_image = cv2.cvtColor(cv_image, cv2.COLOR_RGB2BGR)
    x_center, y_center = get_center_of_bbox(bbox)
    width = get_bbox_width(bbox)
    # Draw the circle on the image
    cv2.circle(
        cv_image,
        center=(x_center, y_center),
        radius=int(width * 0.5 * 0.6),
        color=color,
        thickness=1,
    )
    # Label the circle with the detection id
    cv2.putText(
        cv_image,
        f"{id}",
        (x_center - 6, y_center + 6),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (255, 249, 208),
        2,
    )
    # Convert BGR back to RGB and wrap as a PIL image again
    cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(cv_image)
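
# Hypothetical usage sketch (the image path and values are assumptions,
# not part of this app):
#
#     img = Image.open("./images/example.jpg")
#     img = draw_circle(img, [120, 80, 220, 180], (90, 178, 255), 1)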

def count_predictions(
    image=None,
    hf_model_id="chanelcolgate/rods-count-v1",
    image_size=640,
    conf_threshold=0.25,
    iou_threshold=0.45,
):
    """Run detection on one image, circle each detection, and return the count."""
    model_path = download_from_hub(hf_model_id)
    model = YOLO(model_path)
    results = model(
        image, imgsz=image_size, conf=conf_threshold, iou=iou_threshold
    )
    detections = sv.Detections.from_ultralytics(results[0])
    # Work on a copy so the caller's image is left untouched
    image = image.copy()
    for id, detection in enumerate(detections):
        bbox = detection[0].tolist()
        image = draw_circle(image, bbox, (90, 178, 255), id + 1)
    return image, len(detections)
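
# Hypothetical smoke test for count_predictions (the image path is an
# assumption):
#
#     annotated, total = count_predictions(Image.open("./images/example.jpg"))
#     print(total)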

def count_across_line(
    source_video_path=None,
    hf_model_id="chanelcolgate/cab-v1",
):
    """Track objects through a video and count crossings of a vertical line."""
    TARGET_VIDEO_PATH = os.path.join("./", f"{uuid.uuid4()}.mp4")
    LINE_START = sv.Point(976, 212)
    LINE_END = sv.Point(976, 1276)
    model_path = download_from_hub(hf_model_id)
    model = YOLO(model_path)
    byte_tracker = sv.ByteTrack(
        track_thresh=0.25, track_buffer=30, match_thresh=0.8, frame_rate=30
    )
    line_zone = sv.LineZone(start=LINE_START, end=LINE_END)
    box_annotator = sv.BoxAnnotator(thickness=4, text_thickness=4, text_scale=2)
    trace_annotator = sv.TraceAnnotator(thickness=4, trace_length=50)
    line_zone_annotator = sv.LineZoneAnnotator(
        thickness=4, text_thickness=4, text_scale=2
    )

    def callback(frame: np.ndarray, index: int) -> np.ndarray:
        results = model.predict(frame)
        cls_names = results[0].names
        detection = sv.Detections.from_ultralytics(results[0])
        detection_supervision = byte_tracker.update_with_detections(detection)
        # Labels for the optional labeled-box variant (see the commented
        # `labels` argument below)
        labels_convert = [
            f"#{tracker_id} {cls_names[class_id]} {confidence:0.2f}"
            for _, _, confidence, class_id, tracker_id, _ in detection_supervision
        ]
        annotated_frame = trace_annotator.annotate(
            scene=frame.copy(), detections=detection_supervision
        )
        annotated_frame = box_annotator.annotate(
            scene=annotated_frame,
            detections=detection_supervision,
            skip_label=True,
            # labels=labels_convert,
        )
        # Update the line counter with the tracked detections
        line_zone.trigger(detection_supervision)
        # Return the frame with trace, box, and line annotations
        return line_zone_annotator.annotate(
            annotated_frame, line_counter=line_zone
        )

    # Process the whole video
    sv.process_video(
        source_path=source_video_path,
        target_path=TARGET_VIDEO_PATH,
        callback=callback,
    )
    return TARGET_VIDEO_PATH, line_zone.out_count
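
# Note: sv.LineZone counts crossings in both directions; this function
# returns only line_zone.out_count, but line_zone.in_count is also
# available if the opposite direction matters.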

def count_in_zone(
    source_video_path=None,
    hf_model_id="chanelcolgate/cab-v1",
):
    """Track objects and count how many sit inside each polygon zone."""
    TARGET_VIDEO_PATH = os.path.join("./", f"{uuid.uuid4()}.mp4")
    colors = sv.ColorPalette.default()
    polygons = [
        np.array([[88, 292], [748, 284], [736, 1160], [96, 1148]]),
        np.array([[844, 240], [844, 1132], [1580, 1124], [1584, 264]]),
    ]
    model_path = download_from_hub(hf_model_id)
    model = YOLO(model_path)
    byte_tracker = sv.ByteTrack(
        track_thresh=0.25, track_buffer=30, match_thresh=0.8, frame_rate=30
    )
    video_info = sv.VideoInfo.from_video_path(source_video_path)
    zones = [
        sv.PolygonZone(
            polygon=polygon, frame_resolution_wh=video_info.resolution_wh
        )
        for polygon in polygons
    ]
    zone_annotators = [
        sv.PolygonZoneAnnotator(
            zone=zone,
            color=colors.by_idx(index),
            thickness=4,
            text_thickness=4,
            text_scale=2,
        )
        for index, zone in enumerate(zones)
    ]
    box_annotators = [
        sv.BoxAnnotator(
            thickness=4,
            text_thickness=4,
            text_scale=2,
            color=colors.by_idx(index),
        )
        for index in range(len(polygons))
    ]

    def callback(frame: np.ndarray, index: int) -> np.ndarray:
        results = model.predict(frame)
        detection = sv.Detections.from_ultralytics(results[0])
        detection_supervision = byte_tracker.update_with_detections(detection)
        # For each zone: update its count, then draw boxes and the zone overlay
        for zone, zone_annotator, box_annotator in zip(
            zones, zone_annotators, box_annotators
        ):
            zone.trigger(detections=detection_supervision)
            frame = box_annotator.annotate(
                scene=frame, detections=detection_supervision, skip_label=True
            )
            frame = zone_annotator.annotate(scene=frame)
        return frame

    sv.process_video(
        source_path=source_video_path,
        target_path=TARGET_VIDEO_PATH,
        callback=callback,
    )
    return TARGET_VIDEO_PATH, [zone.current_count for zone in zones]
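
# Note: PolygonZone.current_count is recomputed on every trigger() call,
# so the list returned above reflects zone occupancy in the final frame,
# not a running total across the video.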

title = "Demo Counting"

interface_count_predictions = gr.Interface(
    fn=count_predictions,
    inputs=[
        gr.Image(type="pil"),
        gr.Dropdown(hf_model_ids),
        gr.Slider(
            minimum=320, maximum=1280, value=640, step=32, label="Image Size"
        ),
        gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.25,
            step=0.05,
            label="Confidence Threshold",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.45,
            step=0.05,
            label="IOU Threshold",
        ),
    ],
    outputs=[gr.Image(type="pil"), gr.Textbox(show_label=False)],
    title="Count Predictions",
    examples=image_paths,
    cache_examples=bool(image_paths),
)

interface_count_across_line = gr.Interface(
    fn=count_across_line,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Dropdown(hf_model_ids),
    ],
    outputs=[gr.Video(label="Output Video"), gr.Textbox(show_label=False)],
    title="Count Across Line",
    examples=video_paths,
    cache_examples=bool(video_paths),
)

interface_count_in_zone = gr.Interface(
    fn=count_in_zone,
    inputs=[gr.Video(label="Input Video"), gr.Dropdown(hf_model_ids)],
    outputs=[gr.Video(label="Output Video"), gr.Textbox(show_label=False)],
    title="Count in Zone",
    examples=video_paths,
    cache_examples=bool(video_paths),
)

gr.TabbedInterface(
    [
        interface_count_predictions,
        interface_count_across_line,
        interface_count_in_zone,
    ],
    tab_names=["Count Predictions", "Count Across Line", "Count in Zone"],
    title=title,
).queue().launch()