|
import os
from collections import Counter

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO
|
|
|
# Weight files and their UI display names; the two lists are index-aligned
# (process_image maps a display name back to a model by its list index).
model_options = ["yolo-8n-shiprs.pt", "yolo-8s-shiprs.pt", "yolo-8m-shiprs.pt"]
model_names = ["Nano", "Small", "Medium"]

# Load every model once at start-up so each request only pays for inference.
models = [YOLO(option) for option in model_options]

# Each example row is a single-element list holding an image path.
# sorted() makes the gallery order deterministic (os.listdir order is
# arbitrary); dotfiles such as .DS_Store / .gitkeep are not images and are
# skipped; a missing directory simply yields no examples instead of crashing.
EXAMPLES_DIR = "examples"
example_list = []
if os.path.isdir(EXAMPLES_DIR):
    example_list = [
        [os.path.join(EXAMPLES_DIR, name)]
        for name in sorted(os.listdir(EXAMPLES_DIR))
        if not name.startswith(".")
    ]
|
|
|
def process_image(input_image, model_name, conf):
    """Run oriented-bounding-box detection on one image and summarise it.

    Args:
        input_image: Image array from ``gr.Image()``, assumed to be in RGB
            channel order (it is converted RGB -> BGR before inference), or
            ``None`` when the user has not supplied an image.
        model_name: One of ``model_names``; ``None`` selects the first entry.
        conf: Confidence threshold forwarded to ``model.predict``; ``None``
            falls back to 0.6 (the same default the UI slider uses).

    Returns:
        A ``(image, text)`` tuple. ``image`` is the annotated RGB array, or
        ``None`` when there is no input or nothing was detected. ``text`` is
        either a status message or a per-class count summary.

    Raises:
        ValueError: if ``model_name`` is not in ``model_names``.
    """
    if input_image is None:
        return None, "No image"

    if model_name is None:
        model_name = model_names[0]
    if conf is None:
        conf = 0.6

    # Numpy inputs are fed to Ultralytics in BGR (OpenCV) channel order.
    bgr_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    model = models[model_names.index(model_name)]
    results = model.predict(bgr_image, conf=conf)

    annotated = None
    class_counts = Counter()
    for r in results:
        # Bail out before plotting: there is nothing to draw.
        if len(r.obb.cls) == 0:
            return None, "No objects detected."

        # r.plot() renders the detections in BGR; convert back for display.
        annotated = cv2.cvtColor(r.plot().astype(np.uint8), cv2.COLOR_BGR2RGB)

        # Class ids come back as float tensors; cast to int to key r.names.
        class_counts.update(r.names[int(c)] for c in r.obb.cls)

    if annotated is None:
        # predict() produced no Results at all; without this guard the
        # return below would reference an unbound image.
        return None, "No objects detected."

    summary = "Class Counts:\n" + "".join(
        f"\n{name}: {count}" for name, count in class_counts.items()
    )
    return annotated, summary
|
|
|
# Input order must match process_image(input_image, model_name, conf).
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(),
        gr.Radio(model_names, label="Choose model", value=model_names[0]),
        # Default 0.6 matches process_image's fallback when conf is None.
        gr.Slider(
            minimum=0.2,
            maximum=1.0,
            step=0.1,
            label="Confidence Threshold",
            value=0.6,
        ),
    ],
    outputs=["image", gr.Textbox(label="More info")],
    title="YOLOv8-obb aerial detection",
    description="YOLOv8-obb trained on DOTAv1.5",
    examples=example_list,
)

# Start the (blocking) server only when run as a script, so importing this
# module, e.g. from a test or another script, does not launch it.
if __name__ == "__main__":
    iface.launch()