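# (Hugging Face file-viewer header captured together with the source; kept as comments.)
# VlaTal's picture
# fixed colors
# 0d6882d
# raw
# history blame contribute delete
# No virus
# 1.84 kB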
import os
from collections import Counter

import cv2
import gradio as gr
import numpy as np
from ultralytics import YOLO
# Weight files for the three model variants, paired index-for-index with the
# display names shown in the UI (process_image looks a model up by the index
# of its display name, so the two lists must stay aligned).
model_options = ["yolo-8n-shiprs.pt", "yolo-8s-shiprs.pt", "yolo-8m-shiprs.pt"]
model_names = ["Nano", "Small", "Medium"]
# Every model is loaded once, at import time, and reused for all requests.
models = [YOLO(option) for option in model_options]
# Gradio examples: one single-input row per file in ./examples. Sorted because
# os.listdir order is arbitrary (keeps the example gallery stable), and
# dotfiles such as .gitkeep / .DS_Store are skipped since they are not images.
example_list = [
    [os.path.join("examples", example)]
    for example in sorted(os.listdir("examples"))
    if not example.startswith(".")
]
def process_image(input_image, model_name, conf):
    """Run the selected detector on an image and summarize detections per class.

    Args:
        input_image: Image from the Gradio ``gr.Image`` component, expected to
            be an RGB numpy array, or None when no image was provided.
        model_name: One of ``model_names``; None selects the first model.
        conf: Confidence threshold forwarded to ``model.predict``; None means 0.6.

    Returns:
        A ``(image, text)`` pair for the two Gradio outputs: the annotated RGB
        image and a "Class Counts:" summary, or ``(None, message)`` when there
        is no input image or nothing was detected.

    Raises:
        ValueError: If ``model_name`` is not one of ``model_names``.
    """
    if input_image is None:
        return None, "No image"
    if model_name is None:
        model_name = model_names[0]
    if conf is None:
        conf = 0.6

    # Gradio delivers RGB; Ultralytics treats numpy input as BGR (OpenCV order).
    bgr_image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    model = models[model_names.index(model_name)]
    results = model.predict(bgr_image, conf=conf)

    # One input image yields one Results object. Guard against an empty result
    # list and a result without oriented boxes (obb is None) so the function
    # always returns a pair instead of falling off the end with a bare None.
    if not results:
        return None, "No objects detected."
    result = results[0]
    obb = result.obb
    if obb is None or len(obb.cls) == 0:
        return None, "No objects detected."

    # Class ids are float-valued tensor elements; cast to int before indexing
    # the id -> name mapping. Counter keeps first-seen order, like a dict.
    class_counts = Counter(result.names[int(cls_id)] for cls_id in obb.cls)
    class_counts_str = "Class Counts:\n" + "".join(
        f"\n{name}: {count}" for name, count in class_counts.items()
    )

    # plot() draws the detections onto a BGR image; convert back to RGB for Gradio.
    annotated = cv2.cvtColor(result.plot().astype(np.uint8), cv2.COLOR_BGR2RGB)
    return annotated, class_counts_str
# Assemble the web UI. The input components are listed in the same order as
# process_image's parameters: (input_image, model_name, conf).
iface = gr.Interface(
    process_image,
    [
        gr.Image(type="numpy"),
        gr.Radio(
            choices=model_names,
            value=model_names[0],
            label="Choose model",
        ),
        gr.Slider(
            minimum=0.2,
            maximum=1.0,
            value=0.6,
            step=0.1,
            label="Confidence Threshold",
        ),
    ],
    [
        "image",
        gr.Textbox(label="More info"),
    ],
    title="YOLOv8-obb aerial detection",
    description="YOLOv8-obb trained on DOTAv1.5",
    examples=example_list,
)

# Start the Gradio server when this module runs.
iface.launch()